mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge branch 'main' into 24-web-deploy-aws
This commit is contained in:
@@ -0,0 +1 @@
|
||||
* text=auto eol=lf
|
||||
@@ -60,31 +60,3 @@ In case you haven't done this, have already committed, and CI is failing, you ca
|
||||
### Deployment
|
||||
|
||||
Upon making a release on GitHub, all docker images are automatically built and pushed to ghcr.io. The docker images are tagged with the release version, and the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to automatically deploy the built release to the dev machine.
|
||||
|
||||
# (Older version of the readme below)
|
||||
|
||||
## How do I start helping out?
|
||||
|
||||
Check out these pages to learn more about the project.
|
||||
|
||||
Ping Birger on discord if you want help to get started.
|
||||
|
||||
http://**discordapp.com/users/birger#6875**
|
||||
|
||||
## More information in the notion
|
||||
|
||||
https://roan-iguanadon-a58.notion.site/Open-Chat-Gpt-83dd217eeeb84907a155b8a9d716fa46
|
||||
|
||||
## Code structure
|
||||
|
||||
### Bot
|
||||
|
||||
We have a folder named bot where code related to the bot lives.
|
||||
|
||||
### Backend
|
||||
|
||||
We have a backend folder for backend development of the api that the discord bot sends it information to.
|
||||
|
||||
### Website
|
||||
|
||||
We have a folder for the website, live at https://projects.laion.ai/Open-Chat-GPT/ .The website is built using Next.js
|
||||
|
||||
@@ -7,6 +7,7 @@ Create Date: ${create_date}
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Adds text labels table.
|
||||
|
||||
Revision ID: 067c4002f2d9
|
||||
Revises: 0daec5f8135f
|
||||
Create Date: 2022-12-25 17:05:21.208843
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "067c4002f2d9"
|
||||
down_revision = "0daec5f8135f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"text_labels",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
|
||||
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("post_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
sa.Column("labels", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column("text", sqlmodel.sql.sqltypes.AutoString(length=65536), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["api_client_id"],
|
||||
["api_client.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["post_id"],
|
||||
["post.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("text_labels")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,75 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""add_journal_table
|
||||
|
||||
Revision ID: 3358eb6834e6
|
||||
Revises: 067c4002f2d9
|
||||
Create Date: 2022-12-27 14:44:59.483868
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3358eb6834e6"
|
||||
down_revision = "067c4002f2d9"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"journal",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
"created_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
"event_payload",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("person_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
|
||||
sa.Column("post_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
|
||||
sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column("event_type", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["api_client_id"],
|
||||
["api_client.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["person_id"],
|
||||
["person.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["post_id"],
|
||||
["post.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_journal_person_id"), "journal", ["person_id"], unique=False)
|
||||
op.create_table(
|
||||
"journal_integration",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
|
||||
sa.Column("last_run", sa.DateTime(), nullable=True),
|
||||
sa.Column("description", sqlmodel.sql.sqltypes.AutoString(length=512), nullable=False),
|
||||
sa.Column("last_journal_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
|
||||
sa.Column("last_error", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column("next_run", sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["last_journal_id"],
|
||||
["journal.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", "description"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("journal_integration")
|
||||
op.drop_index(op.f("ix_journal_person_id"), table_name="journal")
|
||||
op.drop_table("journal")
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,6 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from fastapi import APIRouter
|
||||
from oasst_backend.api.v1 import tasks
|
||||
from oasst_backend.api.v1 import tasks, text_labels
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
|
||||
api_router.include_router(text_labels.router, prefix="/text_labels", tags=["text_labels"])
|
||||
|
||||
@@ -7,7 +7,6 @@ from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models.db_payload import TaskPayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
@@ -219,10 +218,6 @@ def post_interaction(
|
||||
f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
work_package = pr.fetch_workpackage_by_postid(interaction.post_id)
|
||||
work_payload: TaskPayload = work_package.payload.payload
|
||||
logger.info(f"found task work package in db: {work_payload}")
|
||||
|
||||
# here we store the text reply in the database
|
||||
# ToDo: role user or agent?
|
||||
pr.store_text_reply(interaction, role="unknown")
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import pydantic
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_400_BAD_REQUEST
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class LabelTextRequest(pydantic.BaseModel):
|
||||
text_labels: protocol_schema.TextLabels
|
||||
user: protocol_schema.User
|
||||
|
||||
|
||||
@router.post("/")
|
||||
def label_text(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
request: LabelTextRequest,
|
||||
) -> None:
|
||||
"""
|
||||
Label a piece of text.
|
||||
"""
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
logger.info(f"Labeling text {request=}.")
|
||||
pr = PromptRepository(db, api_client, user=request.user)
|
||||
pr.store_text_labels(request.text_labels)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to store label.")
|
||||
raise HTTPException(
|
||||
status_code=HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
@@ -0,0 +1,122 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import enum
|
||||
from typing import Literal, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.models import ApiClient, Journal, Person, WorkPackage
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer, payload_type
|
||||
from oasst_shared.utils import utcnow
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
class JournalEventType(str, enum.Enum):
|
||||
"""A label for a piece of text."""
|
||||
|
||||
user_created = "user_created"
|
||||
text_reply_to_post = "text_reply_to_post"
|
||||
post_rating = "post_rating"
|
||||
post_ranking = "post_ranking"
|
||||
|
||||
|
||||
@payload_type
|
||||
class JournalEvent(BaseModel):
|
||||
type: str
|
||||
person_id: Optional[UUID]
|
||||
post_id: Optional[UUID]
|
||||
workpackage_id: Optional[UUID]
|
||||
task_type: Optional[str]
|
||||
|
||||
|
||||
@payload_type
|
||||
class TextReplyEvent(JournalEvent):
|
||||
type: Literal[JournalEventType.text_reply_to_post] = JournalEventType.text_reply_to_post
|
||||
length: int
|
||||
role: str
|
||||
|
||||
|
||||
@payload_type
|
||||
class RatingEvent(JournalEvent):
|
||||
type: Literal[JournalEventType.post_rating] = JournalEventType.post_rating
|
||||
rating: int
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankingEvent(JournalEvent):
|
||||
type: Literal[JournalEventType.post_ranking] = JournalEventType.post_ranking
|
||||
ranking: list[int]
|
||||
|
||||
|
||||
class JournalWriter:
|
||||
def __init__(self, db: Session, api_client: ApiClient, person: Person):
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
self.person = person
|
||||
self.person_id = self.person.id if self.person else None
|
||||
|
||||
def log_text_reply(self, work_package: WorkPackage, post_id: UUID, role: str, length: int) -> Journal:
|
||||
return self.log(
|
||||
task_type=work_package.payload_type,
|
||||
event_type=JournalEventType.text_reply_to_post,
|
||||
payload=TextReplyEvent(role=role, length=length),
|
||||
workpackage_id=work_package.id,
|
||||
post_id=post_id,
|
||||
)
|
||||
|
||||
def log_rating(self, work_package: WorkPackage, post_id: UUID, rating: int) -> Journal:
|
||||
return self.log(
|
||||
task_type=work_package.payload_type,
|
||||
event_type=JournalEventType.post_rating,
|
||||
payload=RatingEvent(rating=rating),
|
||||
workpackage_id=work_package.id,
|
||||
post_id=post_id,
|
||||
)
|
||||
|
||||
def log_ranking(self, work_package: WorkPackage, post_id: UUID, ranking: list[int]) -> Journal:
|
||||
return self.log(
|
||||
task_type=work_package.payload_type,
|
||||
event_type=JournalEventType.post_ranking,
|
||||
payload=RankingEvent(ranking=ranking),
|
||||
workpackage_id=work_package.id,
|
||||
post_id=post_id,
|
||||
)
|
||||
|
||||
def log(
|
||||
self,
|
||||
*,
|
||||
payload: JournalEvent,
|
||||
task_type: str,
|
||||
event_type: str = None,
|
||||
workpackage_id: Optional[UUID] = None,
|
||||
post_id: Optional[UUID] = None,
|
||||
commit: bool = True,
|
||||
) -> Journal:
|
||||
if event_type is None:
|
||||
if payload is None:
|
||||
event_type = "null"
|
||||
else:
|
||||
event_type = type(payload).__name__
|
||||
|
||||
if payload.person_id is None:
|
||||
payload.person_id = self.person_id
|
||||
if payload.post_id is None:
|
||||
payload.post_id = post_id
|
||||
if payload.workpackage_id is None:
|
||||
payload.workpackage_id = workpackage_id
|
||||
if payload.task_type is None:
|
||||
payload.task_type = task_type
|
||||
|
||||
entry = Journal(
|
||||
person_id=self.person_id,
|
||||
api_client_id=self.api_client.id,
|
||||
created_date=utcnow(),
|
||||
event_type=event_type,
|
||||
event_payload=PayloadContainer(payload=payload),
|
||||
post_id=post_id,
|
||||
)
|
||||
|
||||
self.db.add(entry)
|
||||
if commit:
|
||||
self.db.commit()
|
||||
|
||||
return entry
|
||||
@@ -1,9 +1,11 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from .api_client import ApiClient
|
||||
from .journal import Journal, JournalIntegration
|
||||
from .person import Person
|
||||
from .person_stats import PersonStats
|
||||
from .post import Post
|
||||
from .post_reaction import PostReaction
|
||||
from .text_labels import TextLabels
|
||||
from .work_package import WorkPackage
|
||||
|
||||
__all__ = [
|
||||
@@ -13,4 +15,7 @@ __all__ = [
|
||||
"Post",
|
||||
"PostReaction",
|
||||
"WorkPackage",
|
||||
"TextLabels",
|
||||
"Journal",
|
||||
"JournalIntegration",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid1, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from .payload_column_type import PayloadContainer, payload_column_type
|
||||
|
||||
|
||||
def generate_time_uuid(node=None, clock_seq=None):
|
||||
"""Create a lexicographically sortable time ordered custom (non-standard) UUID by reordering the timestamp fields of a version 1 UUID."""
|
||||
(time_low, time_mid, time_hi_version, clock_seq_hi_variant, clock_seq_low, node) = uuid1(node, clock_seq).fields
|
||||
# reconstruct 60 bit timestamp, see version 1 uuid: https://www.rfc-editor.org/rfc/rfc4122
|
||||
timestamp = (time_hi_version & 0xFFF) << 48 | (time_mid << 32) | time_low
|
||||
version = time_hi_version >> 12
|
||||
assert version == 1
|
||||
a = timestamp >> 28 # bits 28-59
|
||||
b = (timestamp >> 12) & 0xFFFF # bits 12-27
|
||||
c = timestamp & 0xFFF # bits 0-11 (clear version bits)
|
||||
clock_seq_hi_variant &= 0xF # (clear variant bits)
|
||||
return UUID(fields=(a, b, c, clock_seq_hi_variant, clock_seq_low, node), version=None)
|
||||
|
||||
|
||||
class Journal(SQLModel, table=True):
|
||||
__tablename__ = "journal"
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), primary_key=True, default=generate_time_uuid),
|
||||
)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True)
|
||||
post_id: Optional[UUID] = Field(foreign_key="post.id", nullable=True)
|
||||
api_client_id: UUID = Field(foreign_key="api_client.id")
|
||||
|
||||
event_type: str = Field(nullable=False, max_length=200)
|
||||
event_payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
|
||||
|
||||
|
||||
class JournalIntegration(SQLModel, table=True):
|
||||
__tablename__ = "journal_integration"
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
|
||||
),
|
||||
)
|
||||
description: str = Field(max_length=512, primary_key=True)
|
||||
last_journal_id: UUID = Field(foreign_key="journal.id", nullable=True)
|
||||
last_run: datetime = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
|
||||
last_error: str = Field(nullable=True)
|
||||
next_run: datetime = Field(nullable=True)
|
||||
@@ -0,0 +1,25 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class TextLabels(SQLModel, table=True):
|
||||
__tablename__ = "text_labels"
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
|
||||
),
|
||||
)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
)
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
text: str = Field(nullable=False, max_length=2**16)
|
||||
post_id: Optional[UUID] = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("post.id"), nullable=True))
|
||||
labels: dict[str, float] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False)
|
||||
@@ -5,7 +5,8 @@ from uuid import UUID, uuid4
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
from loguru import logger
|
||||
from oasst_backend.models import ApiClient, Person, Post, PostReaction, WorkPackage
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
@@ -17,6 +18,7 @@ class PromptRepository:
|
||||
self.api_client = api_client
|
||||
self.person = self.lookup_person(user)
|
||||
self.person_id = self.person.id if self.person else None
|
||||
self.journal = JournalWriter(db, api_client, self.person)
|
||||
|
||||
def lookup_person(self, user: protocol_schema.User) -> Person:
|
||||
if not user:
|
||||
@@ -116,6 +118,10 @@ class PromptRepository:
|
||||
self.validate_post_id(reply.post_id)
|
||||
self.validate_post_id(reply.user_post_id)
|
||||
|
||||
work_package = self.fetch_workpackage_by_postid(reply.post_id)
|
||||
work_payload: db_payload.TaskPayload = work_package.payload.payload
|
||||
logger.info(f"found task work package in db: {work_payload}")
|
||||
|
||||
# find post with post-id
|
||||
parent_post: Post = (
|
||||
self.db.query(Post)
|
||||
@@ -141,6 +147,7 @@ class PromptRepository:
|
||||
role=role,
|
||||
payload=db_payload.PostPayload(text=reply.text),
|
||||
)
|
||||
self.journal.log_text_reply(work_package=work_package, post_id=user_post_id, role=role, length=len(reply.text))
|
||||
return user_post
|
||||
|
||||
def store_rating(self, rating: protocol_schema.PostRating) -> PostReaction:
|
||||
@@ -159,6 +166,7 @@ class PromptRepository:
|
||||
# store reaction to post
|
||||
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
|
||||
reaction = self.insert_reaction(post.id, reaction_payload)
|
||||
self.journal.log_rating(work_package, post_id=post.id, rating=rating.rating)
|
||||
logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.")
|
||||
return reaction
|
||||
|
||||
@@ -184,6 +192,7 @@ class PromptRepository:
|
||||
# store reaction to post
|
||||
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(post.id, reaction_payload)
|
||||
self.journal.log_ranking(work_package, post_id=post.id, ranking=ranking.ranking)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
|
||||
|
||||
@@ -199,6 +208,7 @@ class PromptRepository:
|
||||
# store reaction to post
|
||||
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(post.id, reaction_payload)
|
||||
self.journal.log_ranking(work_package, post_id=post.id, ranking=ranking.ranking)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
|
||||
|
||||
@@ -314,3 +324,17 @@ class PromptRepository:
|
||||
self.db.commit()
|
||||
self.db.refresh(reaction)
|
||||
return reaction
|
||||
|
||||
def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> TextLabels:
|
||||
model = TextLabels(
|
||||
api_client_id=self.api_client.id,
|
||||
text=text_labels.text,
|
||||
labels=text_labels.labels,
|
||||
)
|
||||
if text_labels.has_post_id:
|
||||
self.fetch_post_by_frontend_post_id(text_labels.post_id, fail_if_missing=True)
|
||||
model.post_id = text_labels.post_id
|
||||
self.db.add(model)
|
||||
self.db.commit()
|
||||
self.db.refresh(model)
|
||||
return model
|
||||
|
||||
@@ -5,6 +5,7 @@ numpy==1.22.4
|
||||
psycopg2-binary==2.9.5
|
||||
pydantic==1.9.1
|
||||
python-dotenv==0.21.0
|
||||
scipy==1.8.1
|
||||
SQLAlchemy==1.4.41
|
||||
sqlmodel==0.0.8
|
||||
starlette==0.22.0
|
||||
|
||||
@@ -204,3 +204,51 @@ AnyInteraction = Union[
|
||||
PostRating,
|
||||
PostRanking,
|
||||
]
|
||||
|
||||
|
||||
class TextLabel(str, enum.Enum):
|
||||
"""A label for a piece of text."""
|
||||
|
||||
spam = "spam"
|
||||
violence = "violence"
|
||||
sexual_content = "sexual_content"
|
||||
toxicity = "toxicity"
|
||||
political_content = "political_content"
|
||||
humor = "humor"
|
||||
sarcasm = "sarcasm"
|
||||
hate_speech = "hate_speech"
|
||||
profanity = "profanity"
|
||||
ad_hominem = "ad_hominem"
|
||||
insult = "insult"
|
||||
threat = "threat"
|
||||
aggressive = "aggressive"
|
||||
misleading = "misleading"
|
||||
helpful = "helpful"
|
||||
formal = "formal"
|
||||
cringe = "cringe"
|
||||
creative = "creative"
|
||||
beautiful = "beautiful"
|
||||
informative = "informative"
|
||||
based = "based"
|
||||
slang = "slang"
|
||||
|
||||
|
||||
class TextLabels(BaseModel):
|
||||
"""A set of labels for a piece of text."""
|
||||
|
||||
text: str
|
||||
labels: dict[TextLabel, float]
|
||||
post_id: str | None = None
|
||||
|
||||
@property
|
||||
def has_post_id(self) -> bool:
|
||||
"""Whether this TextLabels has a post_id."""
|
||||
return bool(self.post_id)
|
||||
|
||||
# check that each label value is between 0 and 1
|
||||
@pydantic.validator("labels")
|
||||
def check_label_values(cls, v):
|
||||
for key, value in v.items():
|
||||
if not (0 <= value <= 1):
|
||||
raise ValueError(f"Label values must be between 0 and 1, got {value} for {key}.")
|
||||
return v
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def utcnow() -> datetime:
|
||||
"""Return the current utc date and time with tzinfo set to UTC."""
|
||||
return datetime.now(timezone.utc)
|
||||
@@ -9,6 +9,7 @@ services:
|
||||
environment:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: postgres
|
||||
healthcheck:
|
||||
test: ["CMD", "pg_isready", "-U", "postgres"]
|
||||
interval: 2s
|
||||
|
||||
@@ -16,6 +16,7 @@ services:
|
||||
environment:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: ocgpt_website
|
||||
healthcheck:
|
||||
test: ["CMD", "pg_isready", "-U", "postgres"]
|
||||
interval: 2s
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import numpy as np
|
||||
from scipy import log2
|
||||
from scipy.integrate import nquad
|
||||
from scipy.special import gammaln, psi
|
||||
from scipy.stats import dirichlet
|
||||
|
||||
|
||||
def make_range(*x):
|
||||
"""
|
||||
constructs leftover values for the simplex given the first k entries
|
||||
(0,x_k) = 1-(x_1+...+x_(k-1))
|
||||
"""
|
||||
return (0, max(0, 1 - sum(x)))
|
||||
|
||||
|
||||
def relative_entropy(p, q):
|
||||
"""
|
||||
relative entropy of the two given dirichlet distributions
|
||||
"""
|
||||
|
||||
def tmp(*x):
|
||||
"""
|
||||
First adds the last always forced entry to the input (the last x_last = 1-(x_1+...+x_(N)) )
|
||||
Then computes the relative entropy of posterior and prior for that datapoint
|
||||
"""
|
||||
x_new = np.append(x, 1 - sum(x))
|
||||
return p(x_new) * log2(p(x_new) / q(x_new))
|
||||
|
||||
return tmp
|
||||
|
||||
|
||||
def naive_monte_carlo_integral(fun, dim, samples=10_000_000):
|
||||
s = np.random.rand(dim - 1, samples)
|
||||
s = np.sort(np.concatenate((np.zeros((1, samples)), s, np.ones((1, samples)))), 0)
|
||||
# print(s)
|
||||
pos = np.diff(s, axis=0)
|
||||
# print(pos)
|
||||
res = fun(pos)
|
||||
return np.mean(res)
|
||||
|
||||
|
||||
def analytic_solution(a_post, a_prior):
|
||||
"""
|
||||
Analytic solution to the KL-divergence between two dirichlet distributions.
|
||||
Proof is in the Notion design doc.
|
||||
"""
|
||||
post_sum = np.sum(a_post)
|
||||
prior_sum = np.sum(a_prior)
|
||||
info = (
|
||||
gammaln(post_sum)
|
||||
- gammaln(prior_sum)
|
||||
- np.sum(gammaln(a_post))
|
||||
+ np.sum(gammaln(a_prior))
|
||||
- np.sum((a_post - a_prior) * (psi(a_post) - psi(post_sum)))
|
||||
)
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def infogain(a_post, a_prior):
|
||||
raise (
|
||||
"""For the love of good don't use this:
|
||||
it's insanely poorly conditioned, the worst numerical code I have ever written
|
||||
and it's slow as molasses. Use the analytic solution instead.
|
||||
|
||||
Maybe remove
|
||||
"""
|
||||
)
|
||||
args = len(a_prior)
|
||||
p = dirichlet(a_post).pdf
|
||||
q = dirichlet(a_prior).pdf
|
||||
(info, _) = nquad(relative_entropy(p, q), [make_range for _ in range(args - 1)], opts={"epsabs": 1e-8})
|
||||
# info = naive_monte_carlo_integral(relative_entropy(p,q), len(a_post))
|
||||
return info
|
||||
|
||||
|
||||
def uniform_expected_infogain(a_prior):
|
||||
mean_weight = dirichlet.mean(a_prior)
|
||||
print("weight", mean_weight)
|
||||
results = []
|
||||
for i, w in enumerate(mean_weight):
|
||||
a_post = a_prior.copy()
|
||||
a_post[i] = a_post[i] + 1
|
||||
results.append(w * analytic_solution(a_post, a_prior))
|
||||
return np.sum(results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
a_prior = np.array([1, 1, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||
a_post = np.array([1, 1, 20, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||
|
||||
print("algebraic", analytic_solution(a_post, a_prior))
|
||||
# print("raw",infogain(a_post, a_prior))
|
||||
print("large infogain", uniform_expected_infogain(a_prior))
|
||||
print("post infogain", uniform_expected_infogain(a_post))
|
||||
# a_prior = np.array([1,1,1000])
|
||||
# print("small infogain",uniform_expected_infogain(a_prior))
|
||||
@@ -68,7 +68,7 @@ def get_winner(pairs):
|
||||
def get_ranking(pairs):
|
||||
"""
|
||||
Abuses concordance property to get a (not necessarily unqiue) ranking.
|
||||
The lack of uniqueness is due to the potential existance of multiple
|
||||
The lack of uniqueness is due to the potential existence of multiple
|
||||
equally ranked winners. We have to pick one, which is where
|
||||
the non-uniqueness comes from
|
||||
"""
|
||||
@@ -99,7 +99,7 @@ def ranked_pairs(ranks: List[List[int]]):
|
||||
tallies = tallies - tallies.T
|
||||
# print(tallies)
|
||||
# note: the resulting tally matrix should be skew-symmetric
|
||||
# order by strenght of victory (using tideman's original method, don't think it would make a difference for us)
|
||||
# order by strength of victory (using tideman's original method, don't think it would make a difference for us)
|
||||
sorted_majorities = []
|
||||
for i in range(len(ranks[0])):
|
||||
for j in range(len(ranks[i])):
|
||||
|
||||
@@ -0,0 +1,183 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from scipy.stats import kendalltau
|
||||
|
||||
|
||||
@dataclass
|
||||
class Voter:
|
||||
"""
|
||||
Represents a single voter.
|
||||
This tabulates the number of good votes, total votes,
|
||||
and points.
|
||||
We only put well-behaved people on the scoreboard and filter out the badly behaved ones
|
||||
"""
|
||||
|
||||
uid: Any
|
||||
num_votes: int
|
||||
num_good_votes: int
|
||||
num_prompts: int
|
||||
num_good_prompts: int
|
||||
num_rankings: int
|
||||
num_good_rankings: int
|
||||
|
||||
#####################
|
||||
voting_points: int
|
||||
prompt_points: int
|
||||
ranking_points: int
|
||||
|
||||
def voter_quality(self):
|
||||
return self.num_good_votes / self.num_votes
|
||||
|
||||
def rank_quality(self):
|
||||
return self.num_good_rankings / self.num_rankings
|
||||
|
||||
def prompt_quality(self):
|
||||
return self.num_good_prompts / self.num_prompts
|
||||
|
||||
def is_well_behaved(self, threshhold_vote, threshhold_prompt, threshhold_rank):
|
||||
return (
|
||||
self.voter_quality() > threshhold_vote
|
||||
and self.prompt_quality() > threshhold_prompt
|
||||
and self.rank_quality() > threshhold_rank
|
||||
)
|
||||
|
||||
def total_points(self, voting_weight, prompt_weight, ranking_weight):
|
||||
return (
|
||||
voting_weight * self.voting_points
|
||||
+ prompt_weight * self.prompt_points
|
||||
+ ranking_weight * self.ranking_points
|
||||
)
|
||||
|
||||
|
||||
def score_update_votes(new_vote: int, consensus: npt.ArrayLike, voter_data: Voter) -> Voter:
|
||||
"""
|
||||
This function returns the new "quality score" and points for a voter,
|
||||
after that voter cast a vote on a question.
|
||||
|
||||
This function is only to be run when archiving a question
|
||||
i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information
|
||||
|
||||
The consensus is the array of all votes cast by all voters for that question
|
||||
We then update the voter data using the new information
|
||||
|
||||
Parameters:
|
||||
new_vote (int): the index of the vote cast by the voter
|
||||
consensus (ArrayLike): all votes cast for this question
|
||||
voter_data (Voter): a "Voter" object that represents the person casting the "new_vote"
|
||||
|
||||
Returns:
|
||||
updated_voter (Voter): the new "quality score" and points for the voter
|
||||
"""
|
||||
# produces the ranking of votes, e.g. for [100,300,200] it returns [0, 2, 1],
|
||||
# since 100 is the lowest, 300 the highest and 200 the middle value
|
||||
consensus_ranking = np.argsort(np.argsort(consensus))
|
||||
new_points = consensus_ranking[new_vote] + voter_data.voting_points
|
||||
|
||||
# we need to correct for 0 indexing, if you are closer to "right" than "wrong" of the conensus,
|
||||
# it's a good vote
|
||||
new_good_votes = int(consensus_ranking[new_vote] > (len(consensus) - 1) / 2) + voter_data.num_good_votes
|
||||
new_num_votes = voter_data.num_votes + 1
|
||||
return replace(voter_data, num_votes=new_num_votes, num_good_votes=new_good_votes, voting_points=new_points)
|
||||
|
||||
|
||||
def score_update_prompts(consensus: npt.ArrayLike, voter_data: Voter) -> Voter:
|
||||
"""
|
||||
This function returns the gain of points for a given prompt's votes
|
||||
|
||||
This function is only to be run when archiving a question
|
||||
i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information
|
||||
|
||||
Parameters:
|
||||
consensus (ArrayLike): all votes cast for this question
|
||||
voter_data (Voter): a "Voter" object that represents the person that wrote the prompt
|
||||
|
||||
Returns:
|
||||
updated_voter (Voter): the new "quality score" and points for the voter
|
||||
"""
|
||||
# produces the ranking of votes, e.g. for [100,300,200] it returns [0, 2, 1],
|
||||
# since 100 is the lowest, 300 the highest and 200 the middle value
|
||||
consensus_ranking = np.arange(len(consensus)) - len(consensus) // 2 + 1
|
||||
delta_votes = np.sum(consensus_ranking * consensus)
|
||||
new_points = delta_votes + voter_data.prompt_points
|
||||
|
||||
# we need to correct for 0 indexing, if you are closer to "right" than "wrong" of the conensus,
|
||||
# it's a good vote
|
||||
new_good_prompts = int(delta_votes > 0) + voter_data.num_good_prompts
|
||||
new_num_prompts = voter_data.num_prompts + 1
|
||||
return replace(
|
||||
voter_data,
|
||||
num_prompts=new_num_prompts,
|
||||
num_good_prompts=new_good_prompts,
|
||||
prompt_points=new_points,
|
||||
)
|
||||
|
||||
|
||||
def score_update_ranking(user_ranking: npt.ArrayLike, consensus_ranking: npt.ArrayLike, voter_data: Voter) -> Voter:
|
||||
"""
|
||||
This function returns the gain of points for a given ranking's votes
|
||||
|
||||
This function is only to be run when archiving a question
|
||||
i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information
|
||||
|
||||
we use the bubble-sort distance (or "kendall-tau" distance) to compare the two rankings
|
||||
we use this over spearman correlation since:
|
||||
"[Kendall's τ] approaches a normal distribution more rapidly than ρ, as N, the sample size, increases;
|
||||
and τ is also more tractable mathematically, particularly when ties are present"
|
||||
Gilpin, A. R. (1993). Table for conversion of Kendall's Tau to Spearman's
|
||||
Rho within the context measures of magnitude of effect for meta-analysis
|
||||
|
||||
Further in
|
||||
"research design and statistical analyses, second edition, 2003"
|
||||
the authors note that at least from an significance test POV they will yield the same p-values
|
||||
|
||||
Parameters:
|
||||
user_ranking (ArrayLike): ranking produced by the user
|
||||
consensus (ArrayLike): ranking produced after running the voting algorithm to merge into the consensus ranking
|
||||
voter_data (Voter): a "Voter" object that represents the person that wrote the prompt
|
||||
|
||||
Returns:
|
||||
updated_voter (Voter): the new "quality score" and points for the voter
|
||||
"""
|
||||
bubble_sort_distance, p_value = kendalltau(user_ranking, consensus_ranking)
|
||||
# normalize kendall-tau from [-1,1] into [0,1] range
|
||||
bubble_sort_distance = (1 + bubble_sort_distance) / 2
|
||||
new_points = bubble_sort_distance + voter_data.ranking_points
|
||||
new_good_rankings = int(bubble_sort_distance > 0.5) + voter_data.num_good_rankings
|
||||
new_num_rankings = voter_data.num_rankings + 1
|
||||
return replace(
|
||||
voter_data,
|
||||
num_rankings=new_num_rankings,
|
||||
num_good_rankings=new_good_rankings,
|
||||
ranking_points=new_points,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo_voter = Voter(
|
||||
"abc",
|
||||
num_votes=10,
|
||||
num_good_votes=2,
|
||||
num_prompts=10,
|
||||
num_good_prompts=2,
|
||||
num_rankings=10,
|
||||
num_good_rankings=2,
|
||||
voting_points=6,
|
||||
prompt_points=0,
|
||||
ranking_points=0,
|
||||
)
|
||||
new_vote = 3
|
||||
consensus = np.array([200, 300, 100, 500])
|
||||
print(demo_voter)
|
||||
print("best vote ", score_update_votes(new_vote, consensus, demo_voter))
|
||||
new_vote = 2
|
||||
print("worst vote ", score_update_votes(new_vote, consensus, demo_voter))
|
||||
new_vote = 1
|
||||
print("medium vote ", score_update_votes(new_vote, consensus, demo_voter))
|
||||
print("prompt writer", score_update_prompts(consensus, demo_voter))
|
||||
print("best rank ", score_update_ranking(np.array([0, 2, 1]), np.array([0, 2, 1]), demo_voter))
|
||||
print("medium rank ", score_update_ranking(np.array([2, 0, 1]), np.array([0, 2, 1]), demo_voter))
|
||||
print("worst rank ", score_update_ranking(np.array([1, 0, 2]), np.array([0, 2, 1]), demo_voter))
|
||||
+1
-1
@@ -5,7 +5,7 @@ DATABASE_URL=postgres://postgres:postgres@localhost:5433/ocgpt_website
|
||||
FASTAPI_URL=http://localhost:8080
|
||||
FASTAPI_KEY=1234
|
||||
|
||||
# A dev Auth Secret. Can be exposed if we never use this publically.
|
||||
# A dev Auth Secret. Can be exposed if we never use this publicly.
|
||||
NEXTAUTH_SECRET=O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98=
|
||||
|
||||
# The SMTP host and port found by running the jobs in /scripts/frontend-development/docker-compose.yaml
|
||||
|
||||
@@ -63,6 +63,14 @@ If you're doing active development we suggest the following workflow:
|
||||
navigate to `http://localhost:1080`. Check the email listed and click the
|
||||
log in link. You're now logged in and authenticated.
|
||||
|
||||
### Using debug user credentials
|
||||
|
||||
Whenever the website runs in development mode, you can use the debug credentials provider to log in without fancy emails or OAuth.
|
||||
|
||||
1. Development mode is automatically active when you start the website with `npm run dev`.
|
||||
1. Use the `Login` button in the top right to go to the login page.
|
||||
1. You should see a section for debug credentials. Enter any username you wish, you will be logged in as that user.
|
||||
|
||||
## Code Layout
|
||||
|
||||
### React Code
|
||||
|
||||
@@ -12,7 +12,7 @@ export function Avatar() {
|
||||
return <></>;
|
||||
}
|
||||
if (session && session.user) {
|
||||
const email = session.user.email;
|
||||
const displayName = session.user.name || session.user.email;
|
||||
const accountOptions = [
|
||||
{
|
||||
name: "Account Settings",
|
||||
@@ -35,7 +35,7 @@ export function Avatar() {
|
||||
height="40"
|
||||
className="rounded-full"
|
||||
></Image>
|
||||
<p className="hidden lg:flex">{email}</p>
|
||||
<p className="hidden lg:flex">{displayName}</p>
|
||||
{/* Will be changed to username once it is implemented */}
|
||||
</div>
|
||||
</Popover.Button>
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
import clsx from "clsx";
|
||||
|
||||
export const Button = (
|
||||
props: React.DetailedHTMLProps<React.ButtonHTMLAttributes<HTMLButtonElement>, HTMLButtonElement>
|
||||
) => {
|
||||
const { className, children, ...rest } = props;
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
className={clsx(
|
||||
"inline-flex items-center rounded-md border border-transparent px-4 py-2 text-sm font-medium focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:ring-offset-2",
|
||||
className
|
||||
)}
|
||||
{...rest}
|
||||
>
|
||||
{children}
|
||||
</button>
|
||||
);
|
||||
};
|
||||
@@ -12,13 +12,7 @@ export function Footer() {
|
||||
<div>
|
||||
<div className="flex items-center text-gray-900">
|
||||
<Link href="/" aria-label="Home" className="flex items-center">
|
||||
<Image
|
||||
src="/images/logos/CHAT-THOUGHT-LOGO.svg"
|
||||
className="mx-auto object-fill"
|
||||
width="50"
|
||||
height="50"
|
||||
alt="logo"
|
||||
/>
|
||||
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="50" height="50" alt="logo" />
|
||||
</Link>
|
||||
|
||||
<div className="ml-4">
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
export interface Message {
|
||||
text: string;
|
||||
is_assistant: boolean;
|
||||
}
|
||||
|
||||
const getColor = (isAssistant: boolean) => (isAssistant ? "bg-slate-800" : "bg-sky-900");
|
||||
|
||||
export const Messages = ({ messages }: { messages: Message[] }) => {
|
||||
const items = messages.map(({ text, is_assistant }: Message, i: number) => {
|
||||
return (
|
||||
<div key={i + text} className={`${getColor(is_assistant)} p-4 my-1 rounded-xl text-white whitespace-pre-wrap`}>
|
||||
{text}
|
||||
</div>
|
||||
);
|
||||
});
|
||||
// Maybe also show a legend of the colors?
|
||||
return <>{items}</>;
|
||||
};
|
||||
@@ -0,0 +1,12 @@
|
||||
const RankItem = ({ username, score }) => {
|
||||
return (
|
||||
<div className="flex flex-row justify-between p-6 border-2 border-slate-100 text-left font-semibold hover:bg-sky-50">
|
||||
<div>1</div>
|
||||
<div>@username</div>
|
||||
<div>20.5</div>
|
||||
<div>gold</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default RankItem;
|
||||
@@ -0,0 +1,39 @@
|
||||
import { Card, CardBody, Flex, Heading } from "@chakra-ui/react";
|
||||
import Image from "next/image";
|
||||
import Link from "next/link";
|
||||
|
||||
export type OptionProps = {
|
||||
img: string;
|
||||
alt: string;
|
||||
title: string;
|
||||
link: string;
|
||||
};
|
||||
|
||||
export const TaskOption = (props: OptionProps) => {
|
||||
const { alt, img, title, link } = props;
|
||||
return (
|
||||
<Link href={link}>
|
||||
<Card
|
||||
maxW="300"
|
||||
minW="300"
|
||||
minH="300"
|
||||
maxH="300"
|
||||
className="transition ease-in-out duration-500 sm:grayscale hover:grayscale-0"
|
||||
>
|
||||
<CardBody width="full" height="full">
|
||||
<Flex direction="column" alignItems="center" justifyContent="center">
|
||||
<Image src={img} alt={alt} width={200} height={200} />
|
||||
<Heading
|
||||
mt={-10}
|
||||
className="bg-gradient-to-r from-indigo-600 via-sky-400 to-indigo-700 bg-clip-text tracking-tight text-transparent"
|
||||
textAlign="center"
|
||||
fontSize="3xl"
|
||||
>
|
||||
{title}
|
||||
</Heading>
|
||||
</Flex>
|
||||
</CardBody>
|
||||
</Card>
|
||||
</Link>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,23 @@
|
||||
import { Divider, Flex, Heading } from "@chakra-ui/react";
|
||||
import React from "react";
|
||||
|
||||
export type TaskOptionsProps = {
|
||||
title: string;
|
||||
children: JSX.Element | JSX.Element[];
|
||||
};
|
||||
|
||||
export const TaskOptions = (props: TaskOptionsProps) => {
|
||||
const { title, children } = props;
|
||||
return (
|
||||
<Flex gap={10} wrap="wrap" justifyContent="center">
|
||||
<Heading
|
||||
className="bg-gradient-to-r from-indigo-600 via-sky-400 to-indigo-700 bg-clip-text tracking-tight text-transparent"
|
||||
fontSize="5xl"
|
||||
>
|
||||
{title}
|
||||
</Heading>
|
||||
<Divider mt={-8} />
|
||||
{children}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,29 @@
|
||||
import React from "react";
|
||||
import { TaskOptions } from "./TaskOptions";
|
||||
import { Flex } from "@chakra-ui/react";
|
||||
import { TaskOption } from "./TaskOption";
|
||||
|
||||
export const TaskSelection = () => {
|
||||
return (
|
||||
<Flex gap={10} wrap="wrap" justifyContent="space-evenly" width="full" height="full" alignItems={"center"}>
|
||||
<TaskOptions key="create" title="Create">
|
||||
<TaskOption
|
||||
alt="Summarize Stories"
|
||||
img="/images/logos/logo.svg"
|
||||
title="Summarize stories"
|
||||
link="/summarize/story"
|
||||
/>
|
||||
<TaskOption alt="Reply as User" img="/images/logos/logo.svg" title="Reply as User" link="/create/user_reply" />
|
||||
<TaskOption
|
||||
alt="Reply as Assistant"
|
||||
img="/images/logos/logo.svg"
|
||||
title="Reply as Assistant"
|
||||
link="/create/assistant_reply"
|
||||
/>
|
||||
</TaskOptions>
|
||||
<TaskOptions key="evaluate" title="Evaluate">
|
||||
<TaskOption alt="Rate Prompts" img="/images/logos/logo.svg" title="Rate Prompts" link="/grading/grade-output" />
|
||||
</TaskOptions>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,3 @@
|
||||
export { TaskSelection } from "./TaskSelection";
|
||||
export { TaskOptions } from "./TaskOptions";
|
||||
export { TaskOption } from "./TaskOption";
|
||||
@@ -0,0 +1,14 @@
|
||||
export const TwoColumns = ({ children }: { children: React.ReactNode[] }) => {
|
||||
if (!Array.isArray(children) || children.length !== 2) {
|
||||
throw new Error("TwoColumns expects 2 children");
|
||||
}
|
||||
|
||||
const [first, second] = children;
|
||||
|
||||
return (
|
||||
<section className="mb-8 lt-lg:mb-12 grid lg:gap-x-12 lg:grid-cols-2">
|
||||
<div className="rounded-lg shadow-lg h-full block bg-white p-6">{first}</div>
|
||||
<div className="rounded-lg shadow-lg h-full block bg-white p-6 mt-6 lg:mt-0">{second}</div>
|
||||
</section>
|
||||
);
|
||||
};
|
||||
@@ -2,6 +2,7 @@ import type { AuthOptions } from "next-auth";
|
||||
import NextAuth from "next-auth";
|
||||
import DiscordProvider from "next-auth/providers/discord";
|
||||
import EmailProvider from "next-auth/providers/email";
|
||||
import CredentialsProvider from "next-auth/providers/credentials";
|
||||
import { PrismaAdapter } from "@next-auth/prisma-adapter";
|
||||
|
||||
import prisma from "src/lib/prismadb";
|
||||
@@ -32,6 +33,23 @@ if (process.env.DISCORD_CLIENT_ID) {
|
||||
);
|
||||
}
|
||||
|
||||
if (process.env.NODE_ENV === "development") {
|
||||
providers.push(
|
||||
CredentialsProvider({
|
||||
name: "Debug Credentials",
|
||||
credentials: {
|
||||
username: { label: "Username", type: "text" },
|
||||
},
|
||||
async authorize(credentials) {
|
||||
return {
|
||||
id: credentials.username,
|
||||
name: credentials.username,
|
||||
};
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
export const authOptions: AuthOptions = {
|
||||
// Ensure we can store user data in a database.
|
||||
adapter: PrismaAdapter(prisma),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { Button, Input, Stack } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import Link from "next/link";
|
||||
import { FaDiscord, FaEnvelope, FaBug, FaGithub } from "react-icons/fa";
|
||||
import { getCsrfToken, getProviders, signIn } from "next-auth/react";
|
||||
import { useRef } from "react";
|
||||
import { FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
|
||||
@@ -8,19 +9,28 @@ import { FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
|
||||
import { AuthLayout } from "src/components/AuthLayout";
|
||||
|
||||
export default function Signin({ csrfToken, providers }) {
|
||||
const { discord, email, github } = providers;
|
||||
const { discord, email, github, credentials } = providers;
|
||||
const emailEl = useRef(null);
|
||||
const debugUsernameEl = useRef(null);
|
||||
|
||||
const signinWithDiscord = () => {
|
||||
signIn(discord.id, { callbackUrl: "/" });
|
||||
};
|
||||
const signinWithEmail = () => {
|
||||
|
||||
const signinWithEmail = (ev: React.FormEvent) => {
|
||||
ev.preventDefault();
|
||||
signIn(email.id, { callbackUrl: "/", email: emailEl.current.value });
|
||||
};
|
||||
|
||||
const signinWithGithub = () => {
|
||||
signIn(github.id, { callbackUrl: "/" });
|
||||
};
|
||||
|
||||
function signinWithDebugCredentials(ev: React.FormEvent) {
|
||||
ev.preventDefault();
|
||||
signIn(credentials.id, { callbackUrl: "/", username: debugUsernameEl.current.value });
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
@@ -28,14 +38,27 @@ export default function Signin({ csrfToken, providers }) {
|
||||
<meta name="Sign Up" content="Sign up to access Open Assistant" />
|
||||
</Head>
|
||||
<AuthLayout>
|
||||
<Stack spacing="2">
|
||||
<Stack spacing="6">
|
||||
{credentials && (
|
||||
<form onSubmit={signinWithDebugCredentials} className="border-2 border-orange-200 rounded-md p-4 relative">
|
||||
<span className="text-orange-600 absolute -top-3 left-5 bg-white px-1">For Debugging Only</span>
|
||||
<Stack>
|
||||
<Input variant="outline" size="lg" placeholder="Username" ref={debugUsernameEl} />
|
||||
<Button size={"lg"} leftIcon={<FaBug />} colorScheme="gray" type="submit">
|
||||
Continue with Debug User
|
||||
</Button>
|
||||
</Stack>
|
||||
</form>
|
||||
)}
|
||||
{email && (
|
||||
<Stack>
|
||||
<Input variant="outline" size="lg" placeholder="Email Address" ref={emailEl} />
|
||||
<Button size={"lg"} leftIcon={<FaEnvelope />} colorScheme="gray" onClick={signinWithEmail}>
|
||||
Continue with Email
|
||||
</Button>
|
||||
</Stack>
|
||||
<form onSubmit={signinWithEmail}>
|
||||
<Stack>
|
||||
<Input variant="outline" size="lg" placeholder="Email Address" ref={emailEl} />
|
||||
<Button size={"lg"} leftIcon={<FaEnvelope />} colorScheme="gray">
|
||||
Continue with Email
|
||||
</Button>
|
||||
</Stack>
|
||||
</form>
|
||||
)}
|
||||
{discord && (
|
||||
<Button
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
import { Textarea } from "@chakra-ui/react";
|
||||
import { useRef, useState } from "react";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import { Messages } from "src/components/Messages";
|
||||
import { TwoColumns } from "src/components/TwoColumns";
|
||||
import { Button } from "src/components/Button";
|
||||
|
||||
const AssistantReply = () => {
|
||||
const [tasks, setTasks] = useState([]);
|
||||
|
||||
const inputRef = useRef<HTMLTextAreaElement>(null);
|
||||
|
||||
const { isLoading } = useSWRImmutable("/api/new_task/assistant_reply ", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
console.log(data);
|
||||
setTasks([data]);
|
||||
},
|
||||
});
|
||||
|
||||
const { trigger, isMutating } = useSWRMutation("/api/update_task", poster, {
|
||||
onSuccess: async (data) => {
|
||||
const newTask = await data.json();
|
||||
setTasks((oldTasks) => [...oldTasks, newTask]);
|
||||
},
|
||||
});
|
||||
|
||||
const submitResponse = (task: { id: string }) => {
|
||||
const text = inputRef.current.value.trim();
|
||||
trigger({
|
||||
id: task.id,
|
||||
update_type: "text_reply_to_post",
|
||||
content: {
|
||||
text,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* TODO: Make this a nicer loading screen.
|
||||
*/
|
||||
if (tasks.length == 0) {
|
||||
return <div className="p-6 h-full mx-auto bg-slate-100 text-gray-800">Loading...</div>;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
return (
|
||||
<div className="p-6 h-full mx-auto bg-slate-100 text-gray-800">
|
||||
<TwoColumns>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">Reply as the assistant</h5>
|
||||
<p className="text-lg py-1">Given the following conversation, provide an adequate reply</p>
|
||||
<Messages messages={task.conversation.messages} />
|
||||
</>
|
||||
<Textarea name="reply" placeholder="Reply..." ref={inputRef} />
|
||||
</TwoColumns>
|
||||
|
||||
<section className="mb-8 p-4 rounded-lg shadow-lg bg-white flex flex-row justify-items-stretch ">
|
||||
<div className="grid grid-cols-[min-content_auto] gap-x-2 text-gray-700">
|
||||
<b>Prompt</b>
|
||||
<span>{tasks[0].id}</span>
|
||||
<b>Output</b>
|
||||
<span>Submit your answer</span>
|
||||
</div>
|
||||
|
||||
<div className="flex justify-center ml-auto">
|
||||
<Button className="mr-2 bg-indigo-100 text-indigo-700 hover:bg-indigo-200">Skip</Button>
|
||||
<Button
|
||||
onClick={() => submitResponse(tasks[0])}
|
||||
className="bg-indigo-600 text-white shadow-sm hover:bg-indigo-700"
|
||||
>
|
||||
Submit
|
||||
</Button>
|
||||
</div>
|
||||
</section>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default AssistantReply;
|
||||
@@ -0,0 +1,84 @@
|
||||
import { Textarea } from "@chakra-ui/react";
|
||||
import { useRef, useState } from "react";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import { Messages } from "src/components/Messages";
|
||||
import { TwoColumns } from "src/components/TwoColumns";
|
||||
import { Button } from "src/components/Button";
|
||||
|
||||
const UserReply = () => {
|
||||
const [tasks, setTasks] = useState([]);
|
||||
|
||||
const inputRef = useRef<HTMLTextAreaElement>(null);
|
||||
|
||||
const { isLoading } = useSWRImmutable("/api/new_task/user_reply", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
console.log(data);
|
||||
setTasks([data]);
|
||||
},
|
||||
});
|
||||
|
||||
const { trigger, isMutating } = useSWRMutation("/api/update_task", poster, {
|
||||
onSuccess: async (data) => {
|
||||
const newTask = await data.json();
|
||||
setTasks((oldTasks) => [...oldTasks, newTask]);
|
||||
},
|
||||
});
|
||||
|
||||
const submitResponse = (task: { id: string }) => {
|
||||
const text = inputRef.current.value.trim();
|
||||
trigger({
|
||||
id: task.id,
|
||||
update_type: "text_reply_to_post",
|
||||
content: {
|
||||
text,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* TODO: Make this a nicer loading screen.
|
||||
*/
|
||||
if (tasks.length == 0) {
|
||||
return <div className="p-6 h-full mx-auto bg-slate-100 text-gray-800">Loading...</div>;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
return (
|
||||
<div className="p-6 h-full mx-auto bg-slate-100 text-gray-800">
|
||||
<TwoColumns>
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">Reply as a user</h5>
|
||||
<p className="text-lg py-1">Given the following conversation, provide an adequate reply</p>
|
||||
<Messages messages={task.conversation.messages} />
|
||||
{task.hint && <p className="text-lg py-1">Hint: {task.hint}</p>}
|
||||
</>
|
||||
<Textarea name="reply" placeholder="Reply..." ref={inputRef} />
|
||||
</TwoColumns>
|
||||
|
||||
<section className="mb-8 p-4 rounded-lg shadow-lg bg-white flex flex-row justify-items-stretch ">
|
||||
<div className="grid grid-cols-[min-content_auto] gap-x-2 text-gray-700">
|
||||
<b>Prompt</b>
|
||||
<span>{tasks[0].id}</span>
|
||||
<b>Output</b>
|
||||
<span>Submit your answer</span>
|
||||
</div>
|
||||
|
||||
<div className="flex justify-center ml-auto">
|
||||
<Button className="mr-2 bg-indigo-100 text-indigo-700 hover:bg-indigo-200">Skip</Button>
|
||||
<Button
|
||||
onClick={() => submitResponse(tasks[0])}
|
||||
className="bg-indigo-600 text-white shadow-sm hover:bg-indigo-700"
|
||||
>
|
||||
Submit
|
||||
</Button>
|
||||
</div>
|
||||
</section>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default UserReply;
|
||||
@@ -1,16 +1,11 @@
|
||||
import { useSession } from "next-auth/react";
|
||||
|
||||
import Head from "next/head";
|
||||
import Link from "next/link";
|
||||
import { Button, Input, Stack } from "@chakra-ui/react";
|
||||
|
||||
import { CallToAction } from "../components/CallToAction";
|
||||
import { Faq } from "../components/Faq";
|
||||
import { Footer } from "../components/Footer";
|
||||
import { Header } from "../components/Header";
|
||||
import { Hero } from "../components/Hero";
|
||||
|
||||
import styles from "../styles/Home.module.css";
|
||||
import { TaskSelection } from "../components/TaskSelection";
|
||||
|
||||
export default function Home() {
|
||||
const { data: session } = useSession();
|
||||
@@ -46,10 +41,8 @@ export default function Home() {
|
||||
/>
|
||||
</Head>
|
||||
<Header />
|
||||
<main className="h-3/4 z-0 bg-white flex items-center justify-center">
|
||||
<Button size="lg" colorScheme="blue" className="drop-shadow">
|
||||
<Link href="/grading/grade-output">Rate a prompt and output now</Link>
|
||||
</Button>
|
||||
<main className="h-3/4 m-20 z-0 bg-white flex flex-col items-center justify-center gap-2">
|
||||
<TaskSelection />
|
||||
</main>
|
||||
<Footer />
|
||||
</>
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
import RankItem from "@/components/RankItem";
|
||||
import { BarsArrowUpIcon, BarsArrowDownIcon } from "@heroicons/react/24/solid";
|
||||
import Image from "next/image";
|
||||
import { HiBarsArrowUp, HiBarsArrowDown } from "react-icons/hi2";
|
||||
|
||||
const LeaderBoard = () => {
|
||||
const PlaceHolderProps = { username: "test_user", score: 10 };
|
||||
return (
|
||||
<div className=" p-6 h-full mx-auto bg-slate-100 text-gray-800">
|
||||
<div className="flex flex-col">
|
||||
<div className="rounded-lg shadow-lg h-full block bg-white">
|
||||
<div className="p-8">
|
||||
<h5 className="text-2xl font-bold">LeaderBoard</h5>
|
||||
</div>
|
||||
<div className="flex flex-row justify-between px-6 py-3 font-semibold text-md">
|
||||
<div className="flex flex-row items-center justify-center space-x-2">
|
||||
<div>
|
||||
<p>Rank</p>
|
||||
</div>
|
||||
<div className="mt-2 text-slate-400 hover:text-sky-400 hover:cursor-pointer">
|
||||
<HiBarsArrowDown className="w-6 h-6 text-inherit"></HiBarsArrowDown>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-row items-center justify-center space-x-2">
|
||||
<div>
|
||||
<p>User</p>
|
||||
</div>
|
||||
<div className="mt-2 text-slate-400 hover:text-sky-400 hover:cursor-pointer">
|
||||
<HiBarsArrowDown className="w-6 h-6 text-inherit"></HiBarsArrowDown>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-row items-center justify-center space-x-2">
|
||||
<div>
|
||||
<p>Score</p>
|
||||
</div>
|
||||
<div className="mt-2 text-slate-400 hover:text-sky-400 hover:cursor-pointer">
|
||||
<HiBarsArrowDown className="w-6 h-6 text-inherit"></HiBarsArrowDown>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-row items-center justify-center space-x-2">
|
||||
<div>
|
||||
<p>Medal</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{/* leaderboard items */}
|
||||
<RankItem {...PlaceHolderProps}></RankItem>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default LeaderBoard;
|
||||
@@ -0,0 +1,118 @@
|
||||
// TODO(#65): Unify and simplify the task paths
|
||||
import { Textarea } from "@chakra-ui/react";
|
||||
import { useRef, useState } from "react";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
|
||||
const SummarizeStory = () => {
|
||||
// Use an array of tasks that record the sequence of steps until a task is
|
||||
// deemed complete.
|
||||
const [tasks, setTasks] = useState([]);
|
||||
|
||||
const inputRef = useRef<HTMLTextAreaElement>(null);
|
||||
|
||||
// Fetch the very fist task. We can ignore everything except isLoading
|
||||
// because the onSuccess handler will update `tasks` when ready.
|
||||
const { isLoading } = useSWRImmutable("/api/new_task/summarize_story", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
console.log(data);
|
||||
setTasks([data]);
|
||||
},
|
||||
});
|
||||
|
||||
// Every time we submit an answer to the latest task, let the backend handle
|
||||
// all the interactions then add the resulting task to the queue. This ends
|
||||
// when we hit the done task.
|
||||
const { trigger, isMutating } = useSWRMutation("/api/update_task", poster, {
|
||||
onSuccess: async (data) => {
|
||||
const newTask = await data.json();
|
||||
// This is the more efficient way to update a react state array.
|
||||
setTasks((oldTasks) => [...oldTasks, newTask]);
|
||||
},
|
||||
});
|
||||
|
||||
// Trigger a mutation that updates the current task. We should probably
|
||||
// signal somewhere that this interaction is being processed.
|
||||
const submitResponse = (task: { id: string }) => {
|
||||
const text = inputRef.current.value.trim();
|
||||
trigger({
|
||||
id: task.id,
|
||||
content: {
|
||||
update_type: "text_reply_to_post",
|
||||
text,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* TODO: Make this a nicer loading screen.
|
||||
*/
|
||||
if (tasks.length == 0) {
|
||||
return <div className=" p-6 h-full mx-auto bg-slate-100 text-gray-800">Loading...</div>;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className=" p-6 h-full mx-auto bg-slate-100 text-gray-800">
|
||||
{/* Instrunction and Output panels */}
|
||||
<section className="mb-8 lt-lg:mb-12 ">
|
||||
<div className="grid lg:gap-x-12 lg:grid-cols-2">
|
||||
{/* Instruction panel */}
|
||||
<div className="rounded-lg shadow-lg h-full block bg-white">
|
||||
<div className="p-6">
|
||||
<h5 className="text-lg font-semibold">Instruction</h5>
|
||||
<p className="text-lg py-1">Summarize the following story</p>
|
||||
<div className="bg-slate-800 p-6 rounded-xl text-white whitespace-pre-wrap">{tasks[0].task.story}</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Output panel */}
|
||||
<div className="mt-6 lg:mt-0 rounded-lg shadow-lg h-full block bg-white">
|
||||
<div className="flex justify-center p-6">
|
||||
<Textarea name="summary" placeholder="Summary" ref={inputRef} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
{/* Info & controls */}
|
||||
<section className="mb-8 p-4 rounded-lg shadow-lg bg-white flex flex-row justify-items-stretch ">
|
||||
<div className="flex flex-col justify-self-start text-gray-700">
|
||||
<div>
|
||||
<span>
|
||||
<b>Prompt</b>
|
||||
</span>
|
||||
<span className="ml-2">{tasks[0].id}</span>
|
||||
</div>
|
||||
<div>
|
||||
<span>
|
||||
<b>Output</b>
|
||||
</span>
|
||||
<span className="ml-2">Submit your answer</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Skip / Submit controls */}
|
||||
<div className="flex justify-center ml-auto">
|
||||
<button
|
||||
type="button"
|
||||
className="mr-2 inline-flex items-center rounded-md border border-transparent bg-indigo-100 px-4 py-2 text-sm font-medium text-indigo-700 hover:bg-indigo-200 focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:ring-offset-2"
|
||||
>
|
||||
Skip
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => submitResponse(tasks[0])}
|
||||
className="inline-flex items-center rounded-md border border-transparent bg-indigo-600 px-4 py-2 text-sm font-medium text-white shadow-sm hover:bg-indigo-700 focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:ring-offset-2"
|
||||
>
|
||||
Submit
|
||||
</button>
|
||||
</div>
|
||||
</section>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default SummarizeStory;
|
||||
Reference in New Issue
Block a user