Merge branch 'main' into 24-web-deploy-aws

This commit is contained in:
Keith Stevens
2022-12-28 08:04:23 +09:00
committed by GitHub
40 changed files with 1302 additions and 66 deletions
+1
View File
@@ -0,0 +1 @@
* text=auto eol=lf
-28
View File
@@ -60,31 +60,3 @@ In case you haven't done this, have already committed, and CI is failing, you ca
### Deployment
Upon making a release on GitHub, all docker images are automatically built and pushed to ghcr.io. The docker images are tagged with the release version, and the `latest` tag. Further, the ansible playbook in `ansible/dev.yaml` is run to automatically deploy the built release to the dev machine.
# (Older version of the readme below)
## How do I start helping out?
Check out these pages to learn more about the project.
Ping Birger on discord if you want help to get started.
http://**discordapp.com/users/birger#6875**
## More information in the notion
https://roan-iguanadon-a58.notion.site/Open-Chat-Gpt-83dd217eeeb84907a155b8a9d716fa46
## Code structure
### Bot
We have a folder named bot where code related to the bot lives.
### Backend
We have a backend folder for backend development of the api that the discord bot sends it information to.
### Website
We have a folder for the website, live at https://projects.laion.ai/Open-Chat-GPT/ .The website is built using Next.js
+1
View File
@@ -7,6 +7,7 @@ Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel
${imports if imports else ""}
# revision identifiers, used by Alembic.
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
"""Adds text labels table.
Revision ID: 067c4002f2d9
Revises: 0daec5f8135f
Create Date: 2022-12-25 17:05:21.208843
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "067c4002f2d9"
down_revision = "0daec5f8135f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"text_labels",
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("post_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("labels", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("text", sqlmodel.sql.sqltypes.AutoString(length=65536), nullable=False),
sa.ForeignKeyConstraint(
["api_client_id"],
["api_client.id"],
),
sa.ForeignKeyConstraint(
["post_id"],
["post.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("text_labels")
# ### end Alembic commands ###
@@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-
"""add_journal_table
Revision ID: 3358eb6834e6
Revises: 067c4002f2d9
Create Date: 2022-12-27 14:44:59.483868
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "3358eb6834e6"
down_revision = "067c4002f2d9"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"journal",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column(
"created_date", sa.DateTime(timezone=True), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False
),
sa.Column(
"event_payload",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
),
sa.Column("person_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.Column("post_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("event_type", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
sa.ForeignKeyConstraint(
["api_client_id"],
["api_client.id"],
),
sa.ForeignKeyConstraint(
["person_id"],
["person.id"],
),
sa.ForeignKeyConstraint(
["post_id"],
["post.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(op.f("ix_journal_person_id"), "journal", ["person_id"], unique=False)
op.create_table(
"journal_integration",
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
sa.Column("last_run", sa.DateTime(), nullable=True),
sa.Column("description", sqlmodel.sql.sqltypes.AutoString(length=512), nullable=False),
sa.Column("last_journal_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.Column("last_error", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("next_run", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(
["last_journal_id"],
["journal.id"],
),
sa.PrimaryKeyConstraint("id", "description"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("journal_integration")
op.drop_index(op.f("ix_journal_person_id"), table_name="journal")
op.drop_table("journal")
# ### end Alembic commands ###
+2 -1
View File
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
from fastapi import APIRouter
from oasst_backend.api.v1 import tasks
from oasst_backend.api.v1 import tasks, text_labels
api_router = APIRouter()
api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
api_router.include_router(text_labels.router, prefix="/text_labels", tags=["text_labels"])
-5
View File
@@ -7,7 +7,6 @@ from fastapi import APIRouter, Depends, HTTPException
from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.models.db_payload import TaskPayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
@@ -219,10 +218,6 @@ def post_interaction(
f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}."
)
work_package = pr.fetch_workpackage_by_postid(interaction.post_id)
work_payload: TaskPayload = work_package.payload.payload
logger.info(f"found task work package in db: {work_payload}")
# here we store the text reply in the database
# ToDo: role user or agent?
pr.store_text_reply(interaction, role="unknown")
@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
import pydantic
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_400_BAD_REQUEST
router = APIRouter()
class LabelTextRequest(pydantic.BaseModel):
text_labels: protocol_schema.TextLabels
user: protocol_schema.User
@router.post("/")
def label_text(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
request: LabelTextRequest,
) -> None:
"""
Label a piece of text.
"""
api_client = deps.api_auth(api_key, db)
try:
logger.info(f"Labeling text {request=}.")
pr = PromptRepository(db, api_client, user=request.user)
pr.store_text_labels(request.text_labels)
except Exception:
logger.exception("Failed to store label.")
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
)
+122
View File
@@ -0,0 +1,122 @@
# -*- coding: utf-8 -*-
import enum
from typing import Literal, Optional
from uuid import UUID
from oasst_backend.models import ApiClient, Journal, Person, WorkPackage
from oasst_backend.models.payload_column_type import PayloadContainer, payload_type
from oasst_shared.utils import utcnow
from pydantic import BaseModel
from sqlmodel import Session
class JournalEventType(str, enum.Enum):
"""A label for a piece of text."""
user_created = "user_created"
text_reply_to_post = "text_reply_to_post"
post_rating = "post_rating"
post_ranking = "post_ranking"
@payload_type
class JournalEvent(BaseModel):
type: str
person_id: Optional[UUID]
post_id: Optional[UUID]
workpackage_id: Optional[UUID]
task_type: Optional[str]
@payload_type
class TextReplyEvent(JournalEvent):
type: Literal[JournalEventType.text_reply_to_post] = JournalEventType.text_reply_to_post
length: int
role: str
@payload_type
class RatingEvent(JournalEvent):
type: Literal[JournalEventType.post_rating] = JournalEventType.post_rating
rating: int
@payload_type
class RankingEvent(JournalEvent):
type: Literal[JournalEventType.post_ranking] = JournalEventType.post_ranking
ranking: list[int]
class JournalWriter:
def __init__(self, db: Session, api_client: ApiClient, person: Person):
self.db = db
self.api_client = api_client
self.person = person
self.person_id = self.person.id if self.person else None
def log_text_reply(self, work_package: WorkPackage, post_id: UUID, role: str, length: int) -> Journal:
return self.log(
task_type=work_package.payload_type,
event_type=JournalEventType.text_reply_to_post,
payload=TextReplyEvent(role=role, length=length),
workpackage_id=work_package.id,
post_id=post_id,
)
def log_rating(self, work_package: WorkPackage, post_id: UUID, rating: int) -> Journal:
return self.log(
task_type=work_package.payload_type,
event_type=JournalEventType.post_rating,
payload=RatingEvent(rating=rating),
workpackage_id=work_package.id,
post_id=post_id,
)
def log_ranking(self, work_package: WorkPackage, post_id: UUID, ranking: list[int]) -> Journal:
return self.log(
task_type=work_package.payload_type,
event_type=JournalEventType.post_ranking,
payload=RankingEvent(ranking=ranking),
workpackage_id=work_package.id,
post_id=post_id,
)
def log(
self,
*,
payload: JournalEvent,
task_type: str,
event_type: str = None,
workpackage_id: Optional[UUID] = None,
post_id: Optional[UUID] = None,
commit: bool = True,
) -> Journal:
if event_type is None:
if payload is None:
event_type = "null"
else:
event_type = type(payload).__name__
if payload.person_id is None:
payload.person_id = self.person_id
if payload.post_id is None:
payload.post_id = post_id
if payload.workpackage_id is None:
payload.workpackage_id = workpackage_id
if payload.task_type is None:
payload.task_type = task_type
entry = Journal(
person_id=self.person_id,
api_client_id=self.api_client.id,
created_date=utcnow(),
event_type=event_type,
event_payload=PayloadContainer(payload=payload),
post_id=post_id,
)
self.db.add(entry)
if commit:
self.db.commit()
return entry
+5
View File
@@ -1,9 +1,11 @@
# -*- coding: utf-8 -*-
from .api_client import ApiClient
from .journal import Journal, JournalIntegration
from .person import Person
from .person_stats import PersonStats
from .post import Post
from .post_reaction import PostReaction
from .text_labels import TextLabels
from .work_package import WorkPackage
__all__ = [
@@ -13,4 +15,7 @@ __all__ = [
"Post",
"PostReaction",
"WorkPackage",
"TextLabels",
"Journal",
"JournalIntegration",
]
+56
View File
@@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from uuid import UUID, uuid1, uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, SQLModel
from .payload_column_type import PayloadContainer, payload_column_type
def generate_time_uuid(node=None, clock_seq=None):
"""Create a lexicographically sortable time ordered custom (non-standard) UUID by reordering the timestamp fields of a version 1 UUID."""
(time_low, time_mid, time_hi_version, clock_seq_hi_variant, clock_seq_low, node) = uuid1(node, clock_seq).fields
# reconstruct 60 bit timestamp, see version 1 uuid: https://www.rfc-editor.org/rfc/rfc4122
timestamp = (time_hi_version & 0xFFF) << 48 | (time_mid << 32) | time_low
version = time_hi_version >> 12
assert version == 1
a = timestamp >> 28 # bits 28-59
b = (timestamp >> 12) & 0xFFFF # bits 12-27
c = timestamp & 0xFFF # bits 0-11 (clear version bits)
clock_seq_hi_variant &= 0xF # (clear variant bits)
return UUID(fields=(a, b, c, clock_seq_hi_variant, clock_seq_low, node), version=None)
class Journal(SQLModel, table=True):
__tablename__ = "journal"
id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), primary_key=True, default=generate_time_uuid),
)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
)
person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True)
post_id: Optional[UUID] = Field(foreign_key="post.id", nullable=True)
api_client_id: UUID = Field(foreign_key="api_client.id")
event_type: str = Field(nullable=False, max_length=200)
event_payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
class JournalIntegration(SQLModel, table=True):
__tablename__ = "journal_integration"
id: Optional[UUID] = Field(
sa_column=sa.Column(
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
),
)
description: str = Field(max_length=512, primary_key=True)
last_journal_id: UUID = Field(foreign_key="journal.id", nullable=True)
last_run: datetime = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
last_error: str = Field(nullable=True)
next_run: datetime = Field(nullable=True)
@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
from datetime import datetime
from typing import Optional
from uuid import UUID, uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, SQLModel
class TextLabels(SQLModel, table=True):
__tablename__ = "text_labels"
id: Optional[UUID] = Field(
sa_column=sa.Column(
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
),
)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
)
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
text: str = Field(nullable=False, max_length=2**16)
post_id: Optional[UUID] = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("post.id"), nullable=True))
labels: dict[str, float] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False)
+25 -1
View File
@@ -5,7 +5,8 @@ from uuid import UUID, uuid4
import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.models import ApiClient, Person, Post, PostReaction, WorkPackage
from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
@@ -17,6 +18,7 @@ class PromptRepository:
self.api_client = api_client
self.person = self.lookup_person(user)
self.person_id = self.person.id if self.person else None
self.journal = JournalWriter(db, api_client, self.person)
def lookup_person(self, user: protocol_schema.User) -> Person:
if not user:
@@ -116,6 +118,10 @@ class PromptRepository:
self.validate_post_id(reply.post_id)
self.validate_post_id(reply.user_post_id)
work_package = self.fetch_workpackage_by_postid(reply.post_id)
work_payload: db_payload.TaskPayload = work_package.payload.payload
logger.info(f"found task work package in db: {work_payload}")
# find post with post-id
parent_post: Post = (
self.db.query(Post)
@@ -141,6 +147,7 @@ class PromptRepository:
role=role,
payload=db_payload.PostPayload(text=reply.text),
)
self.journal.log_text_reply(work_package=work_package, post_id=user_post_id, role=role, length=len(reply.text))
return user_post
def store_rating(self, rating: protocol_schema.PostRating) -> PostReaction:
@@ -159,6 +166,7 @@ class PromptRepository:
# store reaction to post
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
reaction = self.insert_reaction(post.id, reaction_payload)
self.journal.log_rating(work_package, post_id=post.id, rating=rating.rating)
logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.")
return reaction
@@ -184,6 +192,7 @@ class PromptRepository:
# store reaction to post
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
reaction = self.insert_reaction(post.id, reaction_payload)
self.journal.log_ranking(work_package, post_id=post.id, ranking=ranking.ranking)
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
@@ -199,6 +208,7 @@ class PromptRepository:
# store reaction to post
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
reaction = self.insert_reaction(post.id, reaction_payload)
self.journal.log_ranking(work_package, post_id=post.id, ranking=ranking.ranking)
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
@@ -314,3 +324,17 @@ class PromptRepository:
self.db.commit()
self.db.refresh(reaction)
return reaction
def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> TextLabels:
model = TextLabels(
api_client_id=self.api_client.id,
text=text_labels.text,
labels=text_labels.labels,
)
if text_labels.has_post_id:
self.fetch_post_by_frontend_post_id(text_labels.post_id, fail_if_missing=True)
model.post_id = text_labels.post_id
self.db.add(model)
self.db.commit()
self.db.refresh(model)
return model
+1
View File
@@ -5,6 +5,7 @@ numpy==1.22.4
psycopg2-binary==2.9.5
pydantic==1.9.1
python-dotenv==0.21.0
scipy==1.8.1
SQLAlchemy==1.4.41
sqlmodel==0.0.8
starlette==0.22.0
@@ -204,3 +204,51 @@ AnyInteraction = Union[
PostRating,
PostRanking,
]
class TextLabel(str, enum.Enum):
"""A label for a piece of text."""
spam = "spam"
violence = "violence"
sexual_content = "sexual_content"
toxicity = "toxicity"
political_content = "political_content"
humor = "humor"
sarcasm = "sarcasm"
hate_speech = "hate_speech"
profanity = "profanity"
ad_hominem = "ad_hominem"
insult = "insult"
threat = "threat"
aggressive = "aggressive"
misleading = "misleading"
helpful = "helpful"
formal = "formal"
cringe = "cringe"
creative = "creative"
beautiful = "beautiful"
informative = "informative"
based = "based"
slang = "slang"
class TextLabels(BaseModel):
"""A set of labels for a piece of text."""
text: str
labels: dict[TextLabel, float]
post_id: str | None = None
@property
def has_post_id(self) -> bool:
"""Whether this TextLabels has a post_id."""
return bool(self.post_id)
# check that each label value is between 0 and 1
@pydantic.validator("labels")
def check_label_values(cls, v):
for key, value in v.items():
if not (0 <= value <= 1):
raise ValueError(f"Label values must be between 0 and 1, got {value} for {key}.")
return v
+7
View File
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
from datetime import datetime, timezone
def utcnow() -> datetime:
"""Return the current utc date and time with tzinfo set to UTC."""
return datetime.now(timezone.utc)
@@ -9,6 +9,7 @@ services:
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
healthcheck:
test: ["CMD", "pg_isready", "-U", "postgres"]
interval: 2s
@@ -16,6 +16,7 @@ services:
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: ocgpt_website
healthcheck:
test: ["CMD", "pg_isready", "-U", "postgres"]
interval: 2s
@@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
import numpy as np
from scipy import log2
from scipy.integrate import nquad
from scipy.special import gammaln, psi
from scipy.stats import dirichlet
def make_range(*x):
"""
constructs leftover values for the simplex given the first k entries
(0,x_k) = 1-(x_1+...+x_(k-1))
"""
return (0, max(0, 1 - sum(x)))
def relative_entropy(p, q):
"""
relative entropy of the two given dirichlet distributions
"""
def tmp(*x):
"""
First adds the last always forced entry to the input (the last x_last = 1-(x_1+...+x_(N)) )
Then computes the relative entropy of posterior and prior for that datapoint
"""
x_new = np.append(x, 1 - sum(x))
return p(x_new) * log2(p(x_new) / q(x_new))
return tmp
def naive_monte_carlo_integral(fun, dim, samples=10_000_000):
s = np.random.rand(dim - 1, samples)
s = np.sort(np.concatenate((np.zeros((1, samples)), s, np.ones((1, samples)))), 0)
# print(s)
pos = np.diff(s, axis=0)
# print(pos)
res = fun(pos)
return np.mean(res)
def analytic_solution(a_post, a_prior):
"""
Analytic solution to the KL-divergence between two dirichlet distributions.
Proof is in the Notion design doc.
"""
post_sum = np.sum(a_post)
prior_sum = np.sum(a_prior)
info = (
gammaln(post_sum)
- gammaln(prior_sum)
- np.sum(gammaln(a_post))
+ np.sum(gammaln(a_prior))
- np.sum((a_post - a_prior) * (psi(a_post) - psi(post_sum)))
)
return info
def infogain(a_post, a_prior):
raise (
"""For the love of good don't use this:
it's insanely poorly conditioned, the worst numerical code I have ever written
and it's slow as molasses. Use the analytic solution instead.
Maybe remove
"""
)
args = len(a_prior)
p = dirichlet(a_post).pdf
q = dirichlet(a_prior).pdf
(info, _) = nquad(relative_entropy(p, q), [make_range for _ in range(args - 1)], opts={"epsabs": 1e-8})
# info = naive_monte_carlo_integral(relative_entropy(p,q), len(a_post))
return info
def uniform_expected_infogain(a_prior):
mean_weight = dirichlet.mean(a_prior)
print("weight", mean_weight)
results = []
for i, w in enumerate(mean_weight):
a_post = a_prior.copy()
a_post[i] = a_post[i] + 1
results.append(w * analytic_solution(a_post, a_prior))
return np.sum(results)
if __name__ == "__main__":
a_prior = np.array([1, 1, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
a_post = np.array([1, 1, 20, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
print("algebraic", analytic_solution(a_post, a_prior))
# print("raw",infogain(a_post, a_prior))
print("large infogain", uniform_expected_infogain(a_prior))
print("post infogain", uniform_expected_infogain(a_post))
# a_prior = np.array([1,1,1000])
# print("small infogain",uniform_expected_infogain(a_prior))
+2 -2
View File
@@ -68,7 +68,7 @@ def get_winner(pairs):
def get_ranking(pairs):
"""
Abuses concordance property to get a (not necessarily unqiue) ranking.
The lack of uniqueness is due to the potential existance of multiple
The lack of uniqueness is due to the potential existence of multiple
equally ranked winners. We have to pick one, which is where
the non-uniqueness comes from
"""
@@ -99,7 +99,7 @@ def ranked_pairs(ranks: List[List[int]]):
tallies = tallies - tallies.T
# print(tallies)
# note: the resulting tally matrix should be skew-symmetric
# order by strenght of victory (using tideman's original method, don't think it would make a difference for us)
# order by strength of victory (using tideman's original method, don't think it would make a difference for us)
sorted_majorities = []
for i in range(len(ranks[0])):
for j in range(len(ranks[i])):
+183
View File
@@ -0,0 +1,183 @@
# -*- coding: utf-8 -*-
from dataclasses import dataclass, replace
from typing import Any
import numpy as np
import numpy.typing as npt
from scipy.stats import kendalltau
@dataclass
class Voter:
"""
Represents a single voter.
This tabulates the number of good votes, total votes,
and points.
We only put well-behaved people on the scoreboard and filter out the badly behaved ones
"""
uid: Any
num_votes: int
num_good_votes: int
num_prompts: int
num_good_prompts: int
num_rankings: int
num_good_rankings: int
#####################
voting_points: int
prompt_points: int
ranking_points: int
def voter_quality(self):
return self.num_good_votes / self.num_votes
def rank_quality(self):
return self.num_good_rankings / self.num_rankings
def prompt_quality(self):
return self.num_good_prompts / self.num_prompts
def is_well_behaved(self, threshhold_vote, threshhold_prompt, threshhold_rank):
return (
self.voter_quality() > threshhold_vote
and self.prompt_quality() > threshhold_prompt
and self.rank_quality() > threshhold_rank
)
def total_points(self, voting_weight, prompt_weight, ranking_weight):
return (
voting_weight * self.voting_points
+ prompt_weight * self.prompt_points
+ ranking_weight * self.ranking_points
)
def score_update_votes(new_vote: int, consensus: npt.ArrayLike, voter_data: Voter) -> Voter:
"""
This function returns the new "quality score" and points for a voter,
after that voter cast a vote on a question.
This function is only to be run when archiving a question
i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information
The consensus is the array of all votes cast by all voters for that question
We then update the voter data using the new information
Parameters:
new_vote (int): the index of the vote cast by the voter
consensus (ArrayLike): all votes cast for this question
voter_data (Voter): a "Voter" object that represents the person casting the "new_vote"
Returns:
updated_voter (Voter): the new "quality score" and points for the voter
"""
# produces the ranking of votes, e.g. for [100,300,200] it returns [0, 2, 1],
# since 100 is the lowest, 300 the highest and 200 the middle value
consensus_ranking = np.argsort(np.argsort(consensus))
new_points = consensus_ranking[new_vote] + voter_data.voting_points
# we need to correct for 0 indexing, if you are closer to "right" than "wrong" of the conensus,
# it's a good vote
new_good_votes = int(consensus_ranking[new_vote] > (len(consensus) - 1) / 2) + voter_data.num_good_votes
new_num_votes = voter_data.num_votes + 1
return replace(voter_data, num_votes=new_num_votes, num_good_votes=new_good_votes, voting_points=new_points)
def score_update_prompts(consensus: npt.ArrayLike, voter_data: Voter) -> Voter:
"""
This function returns the gain of points for a given prompt's votes
This function is only to be run when archiving a question
i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information
Parameters:
consensus (ArrayLike): all votes cast for this question
voter_data (Voter): a "Voter" object that represents the person that wrote the prompt
Returns:
updated_voter (Voter): the new "quality score" and points for the voter
"""
# produces the ranking of votes, e.g. for [100,300,200] it returns [0, 2, 1],
# since 100 is the lowest, 300 the highest and 200 the middle value
consensus_ranking = np.arange(len(consensus)) - len(consensus) // 2 + 1
delta_votes = np.sum(consensus_ranking * consensus)
new_points = delta_votes + voter_data.prompt_points
# we need to correct for 0 indexing, if you are closer to "right" than "wrong" of the conensus,
# it's a good vote
new_good_prompts = int(delta_votes > 0) + voter_data.num_good_prompts
new_num_prompts = voter_data.num_prompts + 1
return replace(
voter_data,
num_prompts=new_num_prompts,
num_good_prompts=new_good_prompts,
prompt_points=new_points,
)
def score_update_ranking(user_ranking: npt.ArrayLike, consensus_ranking: npt.ArrayLike, voter_data: Voter) -> Voter:
"""
This function returns the gain of points for a given ranking's votes
This function is only to be run when archiving a question
i.e. the question has had sufficiently many votes, or we cann't get more than "K" bits of information
we use the bubble-sort distance (or "kendall-tau" distance) to compare the two rankings
we use this over spearman correlation since:
"[Kendall's τ] approaches a normal distribution more rapidly than ρ, as N, the sample size, increases;
and τ is also more tractable mathematically, particularly when ties are present"
Gilpin, A. R. (1993). Table for conversion of Kendall's Tau to Spearman's
Rho within the context measures of magnitude of effect for meta-analysis
Further in
"research design and statistical analyses, second edition, 2003"
the authors note that at least from an significance test POV they will yield the same p-values
Parameters:
user_ranking (ArrayLike): ranking produced by the user
consensus (ArrayLike): ranking produced after running the voting algorithm to merge into the consensus ranking
voter_data (Voter): a "Voter" object that represents the person that wrote the prompt
Returns:
updated_voter (Voter): the new "quality score" and points for the voter
"""
bubble_sort_distance, p_value = kendalltau(user_ranking, consensus_ranking)
# normalize kendall-tau from [-1,1] into [0,1] range
bubble_sort_distance = (1 + bubble_sort_distance) / 2
new_points = bubble_sort_distance + voter_data.ranking_points
new_good_rankings = int(bubble_sort_distance > 0.5) + voter_data.num_good_rankings
new_num_rankings = voter_data.num_rankings + 1
return replace(
voter_data,
num_rankings=new_num_rankings,
num_good_rankings=new_good_rankings,
ranking_points=new_points,
)
if __name__ == "__main__":
demo_voter = Voter(
"abc",
num_votes=10,
num_good_votes=2,
num_prompts=10,
num_good_prompts=2,
num_rankings=10,
num_good_rankings=2,
voting_points=6,
prompt_points=0,
ranking_points=0,
)
new_vote = 3
consensus = np.array([200, 300, 100, 500])
print(demo_voter)
print("best vote ", score_update_votes(new_vote, consensus, demo_voter))
new_vote = 2
print("worst vote ", score_update_votes(new_vote, consensus, demo_voter))
new_vote = 1
print("medium vote ", score_update_votes(new_vote, consensus, demo_voter))
print("prompt writer", score_update_prompts(consensus, demo_voter))
print("best rank ", score_update_ranking(np.array([0, 2, 1]), np.array([0, 2, 1]), demo_voter))
print("medium rank ", score_update_ranking(np.array([2, 0, 1]), np.array([0, 2, 1]), demo_voter))
print("worst rank ", score_update_ranking(np.array([1, 0, 2]), np.array([0, 2, 1]), demo_voter))
+1 -1
View File
@@ -5,7 +5,7 @@ DATABASE_URL=postgres://postgres:postgres@localhost:5433/ocgpt_website
FASTAPI_URL=http://localhost:8080
FASTAPI_KEY=1234
# A dev Auth Secret. Can be exposed if we never use this publically.
# A dev Auth Secret. Can be exposed if we never use this publicly.
NEXTAUTH_SECRET=O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98=
# The SMTP host and port found by running the jobs in /scripts/frontend-development/docker-compose.yaml
+8
View File
@@ -63,6 +63,14 @@ If you're doing active development we suggest the following workflow:
navigate to `http://localhost:1080`. Check the email listed and click the
log in link. You're now logged in and authenticated.
### Using debug user credentials
Whenever the website runs in development mode, you can use the debug credentials provider to log in without fancy emails or OAuth.
1. Development mode is automatically active when you start the website with `npm run dev`.
1. Use the `Login` button in the top right to go to the login page.
1. You should see a section for debug credentials. Enter any username you wish, you will be logged in as that user.
## Code Layout
### React Code
+2 -2
View File
@@ -12,7 +12,7 @@ export function Avatar() {
return <></>;
}
if (session && session.user) {
const email = session.user.email;
const displayName = session.user.name || session.user.email;
const accountOptions = [
{
name: "Account Settings",
@@ -35,7 +35,7 @@ export function Avatar() {
height="40"
className="rounded-full"
></Image>
<p className="hidden lg:flex">{email}</p>
<p className="hidden lg:flex">{displayName}</p>
{/* Will be changed to username once it is implemented */}
</div>
</Popover.Button>
+19
View File
@@ -0,0 +1,19 @@
import clsx from "clsx";
export const Button = (
props: React.DetailedHTMLProps<React.ButtonHTMLAttributes<HTMLButtonElement>, HTMLButtonElement>
) => {
const { className, children, ...rest } = props;
return (
<button
type="button"
className={clsx(
"inline-flex items-center rounded-md border border-transparent px-4 py-2 text-sm font-medium focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:ring-offset-2",
className
)}
{...rest}
>
{children}
</button>
);
};
+1 -7
View File
@@ -12,13 +12,7 @@ export function Footer() {
<div>
<div className="flex items-center text-gray-900">
<Link href="/" aria-label="Home" className="flex items-center">
<Image
src="/images/logos/CHAT-THOUGHT-LOGO.svg"
className="mx-auto object-fill"
width="50"
height="50"
alt="logo"
/>
<Image src="/images/logos/logo.svg" className="mx-auto object-fill" width="50" height="50" alt="logo" />
</Link>
<div className="ml-4">
+18
View File
@@ -0,0 +1,18 @@
export interface Message {
text: string;
is_assistant: boolean;
}
const getColor = (isAssistant: boolean) => (isAssistant ? "bg-slate-800" : "bg-sky-900");
export const Messages = ({ messages }: { messages: Message[] }) => {
const items = messages.map(({ text, is_assistant }: Message, i: number) => {
return (
<div key={i + text} className={`${getColor(is_assistant)} p-4 my-1 rounded-xl text-white whitespace-pre-wrap`}>
{text}
</div>
);
});
// Maybe also show a legend of the colors?
return <>{items}</>;
};
+12
View File
@@ -0,0 +1,12 @@
const RankItem = ({ username, score }) => {
return (
<div className="flex flex-row justify-between p-6 border-2 border-slate-100 text-left font-semibold hover:bg-sky-50">
<div>1</div>
<div>@username</div>
<div>20.5</div>
<div>gold</div>
</div>
);
};
export default RankItem;
@@ -0,0 +1,39 @@
import { Card, CardBody, Flex, Heading } from "@chakra-ui/react";
import Image from "next/image";
import Link from "next/link";
export type OptionProps = {
img: string;
alt: string;
title: string;
link: string;
};
export const TaskOption = (props: OptionProps) => {
const { alt, img, title, link } = props;
return (
<Link href={link}>
<Card
maxW="300"
minW="300"
minH="300"
maxH="300"
className="transition ease-in-out duration-500 sm:grayscale hover:grayscale-0"
>
<CardBody width="full" height="full">
<Flex direction="column" alignItems="center" justifyContent="center">
<Image src={img} alt={alt} width={200} height={200} />
<Heading
mt={-10}
className="bg-gradient-to-r from-indigo-600 via-sky-400 to-indigo-700 bg-clip-text tracking-tight text-transparent"
textAlign="center"
fontSize="3xl"
>
{title}
</Heading>
</Flex>
</CardBody>
</Card>
</Link>
);
};
@@ -0,0 +1,23 @@
import { Divider, Flex, Heading } from "@chakra-ui/react";
import React from "react";
export type TaskOptionsProps = {
title: string;
children: JSX.Element | JSX.Element[];
};
export const TaskOptions = (props: TaskOptionsProps) => {
const { title, children } = props;
return (
<Flex gap={10} wrap="wrap" justifyContent="center">
<Heading
className="bg-gradient-to-r from-indigo-600 via-sky-400 to-indigo-700 bg-clip-text tracking-tight text-transparent"
fontSize="5xl"
>
{title}
</Heading>
<Divider mt={-8} />
{children}
</Flex>
);
};
@@ -0,0 +1,29 @@
import React from "react";
import { TaskOptions } from "./TaskOptions";
import { Flex } from "@chakra-ui/react";
import { TaskOption } from "./TaskOption";
export const TaskSelection = () => {
return (
<Flex gap={10} wrap="wrap" justifyContent="space-evenly" width="full" height="full" alignItems={"center"}>
<TaskOptions key="create" title="Create">
<TaskOption
alt="Summarize Stories"
img="/images/logos/logo.svg"
title="Summarize stories"
link="/summarize/story"
/>
<TaskOption alt="Reply as User" img="/images/logos/logo.svg" title="Reply as User" link="/create/user_reply" />
<TaskOption
alt="Reply as Assistant"
img="/images/logos/logo.svg"
title="Reply as Assistant"
link="/create/assistant_reply"
/>
</TaskOptions>
<TaskOptions key="evaluate" title="Evaluate">
<TaskOption alt="Rate Prompts" img="/images/logos/logo.svg" title="Rate Prompts" link="/grading/grade-output" />
</TaskOptions>
</Flex>
);
};
@@ -0,0 +1,3 @@
export { TaskSelection } from "./TaskSelection";
export { TaskOptions } from "./TaskOptions";
export { TaskOption } from "./TaskOption";
+14
View File
@@ -0,0 +1,14 @@
export const TwoColumns = ({ children }: { children: React.ReactNode[] }) => {
if (!Array.isArray(children) || children.length !== 2) {
throw new Error("TwoColumns expects 2 children");
}
const [first, second] = children;
return (
<section className="mb-8 lt-lg:mb-12 grid lg:gap-x-12 lg:grid-cols-2">
<div className="rounded-lg shadow-lg h-full block bg-white p-6">{first}</div>
<div className="rounded-lg shadow-lg h-full block bg-white p-6 mt-6 lg:mt-0">{second}</div>
</section>
);
};
@@ -2,6 +2,7 @@ import type { AuthOptions } from "next-auth";
import NextAuth from "next-auth";
import DiscordProvider from "next-auth/providers/discord";
import EmailProvider from "next-auth/providers/email";
import CredentialsProvider from "next-auth/providers/credentials";
import { PrismaAdapter } from "@next-auth/prisma-adapter";
import prisma from "src/lib/prismadb";
@@ -32,6 +33,23 @@ if (process.env.DISCORD_CLIENT_ID) {
);
}
if (process.env.NODE_ENV === "development") {
providers.push(
CredentialsProvider({
name: "Debug Credentials",
credentials: {
username: { label: "Username", type: "text" },
},
async authorize(credentials) {
return {
id: credentials.username,
name: credentials.username,
};
},
})
);
}
export const authOptions: AuthOptions = {
// Ensure we can store user data in a database.
adapter: PrismaAdapter(prisma),
+32 -9
View File
@@ -1,6 +1,7 @@
import { Button, Input, Stack } from "@chakra-ui/react";
import Head from "next/head";
import Link from "next/link";
import { FaDiscord, FaEnvelope, FaBug, FaGithub } from "react-icons/fa";
import { getCsrfToken, getProviders, signIn } from "next-auth/react";
import { useRef } from "react";
import { FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
@@ -8,19 +9,28 @@ import { FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
import { AuthLayout } from "src/components/AuthLayout";
export default function Signin({ csrfToken, providers }) {
const { discord, email, github } = providers;
const { discord, email, github, credentials } = providers;
const emailEl = useRef(null);
const debugUsernameEl = useRef(null);
const signinWithDiscord = () => {
signIn(discord.id, { callbackUrl: "/" });
};
const signinWithEmail = () => {
const signinWithEmail = (ev: React.FormEvent) => {
ev.preventDefault();
signIn(email.id, { callbackUrl: "/", email: emailEl.current.value });
};
const signinWithGithub = () => {
signIn(github.id, { callbackUrl: "/" });
};
function signinWithDebugCredentials(ev: React.FormEvent) {
ev.preventDefault();
signIn(credentials.id, { callbackUrl: "/", username: debugUsernameEl.current.value });
}
return (
<>
<Head>
@@ -28,14 +38,27 @@ export default function Signin({ csrfToken, providers }) {
<meta name="Sign Up" content="Sign up to access Open Assistant" />
</Head>
<AuthLayout>
<Stack spacing="2">
<Stack spacing="6">
{credentials && (
<form onSubmit={signinWithDebugCredentials} className="border-2 border-orange-200 rounded-md p-4 relative">
<span className="text-orange-600 absolute -top-3 left-5 bg-white px-1">For Debugging Only</span>
<Stack>
<Input variant="outline" size="lg" placeholder="Username" ref={debugUsernameEl} />
<Button size={"lg"} leftIcon={<FaBug />} colorScheme="gray" type="submit">
Continue with Debug User
</Button>
</Stack>
</form>
)}
{email && (
<Stack>
<Input variant="outline" size="lg" placeholder="Email Address" ref={emailEl} />
<Button size={"lg"} leftIcon={<FaEnvelope />} colorScheme="gray" onClick={signinWithEmail}>
Continue with Email
</Button>
</Stack>
<form onSubmit={signinWithEmail}>
<Stack>
<Input variant="outline" size="lg" placeholder="Email Address" ref={emailEl} />
<Button size={"lg"} leftIcon={<FaEnvelope />} colorScheme="gray">
Continue with Email
</Button>
</Stack>
</form>
)}
{discord && (
<Button
@@ -0,0 +1,83 @@
import { Textarea } from "@chakra-ui/react";
import { useRef, useState } from "react";
import useSWRMutation from "swr/mutation";
import useSWRImmutable from "swr/immutable";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import { Messages } from "src/components/Messages";
import { TwoColumns } from "src/components/TwoColumns";
import { Button } from "src/components/Button";
const AssistantReply = () => {
const [tasks, setTasks] = useState([]);
const inputRef = useRef<HTMLTextAreaElement>(null);
const { isLoading } = useSWRImmutable("/api/new_task/assistant_reply ", fetcher, {
onSuccess: (data) => {
console.log(data);
setTasks([data]);
},
});
const { trigger, isMutating } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
setTasks((oldTasks) => [...oldTasks, newTask]);
},
});
const submitResponse = (task: { id: string }) => {
const text = inputRef.current.value.trim();
trigger({
id: task.id,
update_type: "text_reply_to_post",
content: {
text,
},
});
};
/**
* TODO: Make this a nicer loading screen.
*/
if (tasks.length == 0) {
return <div className="p-6 h-full mx-auto bg-slate-100 text-gray-800">Loading...</div>;
}
const task = tasks[0].task;
return (
<div className="p-6 h-full mx-auto bg-slate-100 text-gray-800">
<TwoColumns>
<>
<h5 className="text-lg font-semibold">Reply as the assistant</h5>
<p className="text-lg py-1">Given the following conversation, provide an adequate reply</p>
<Messages messages={task.conversation.messages} />
</>
<Textarea name="reply" placeholder="Reply..." ref={inputRef} />
</TwoColumns>
<section className="mb-8 p-4 rounded-lg shadow-lg bg-white flex flex-row justify-items-stretch ">
<div className="grid grid-cols-[min-content_auto] gap-x-2 text-gray-700">
<b>Prompt</b>
<span>{tasks[0].id}</span>
<b>Output</b>
<span>Submit your answer</span>
</div>
<div className="flex justify-center ml-auto">
<Button className="mr-2 bg-indigo-100 text-indigo-700 hover:bg-indigo-200">Skip</Button>
<Button
onClick={() => submitResponse(tasks[0])}
className="bg-indigo-600 text-white shadow-sm hover:bg-indigo-700"
>
Submit
</Button>
</div>
</section>
</div>
);
};
export default AssistantReply;
+84
View File
@@ -0,0 +1,84 @@
import { Textarea } from "@chakra-ui/react";
import { useRef, useState } from "react";
import useSWRMutation from "swr/mutation";
import useSWRImmutable from "swr/immutable";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import { Messages } from "src/components/Messages";
import { TwoColumns } from "src/components/TwoColumns";
import { Button } from "src/components/Button";
const UserReply = () => {
const [tasks, setTasks] = useState([]);
const inputRef = useRef<HTMLTextAreaElement>(null);
const { isLoading } = useSWRImmutable("/api/new_task/user_reply", fetcher, {
onSuccess: (data) => {
console.log(data);
setTasks([data]);
},
});
const { trigger, isMutating } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
setTasks((oldTasks) => [...oldTasks, newTask]);
},
});
const submitResponse = (task: { id: string }) => {
const text = inputRef.current.value.trim();
trigger({
id: task.id,
update_type: "text_reply_to_post",
content: {
text,
},
});
};
/**
* TODO: Make this a nicer loading screen.
*/
if (tasks.length == 0) {
return <div className="p-6 h-full mx-auto bg-slate-100 text-gray-800">Loading...</div>;
}
const task = tasks[0].task;
return (
<div className="p-6 h-full mx-auto bg-slate-100 text-gray-800">
<TwoColumns>
<>
<h5 className="text-lg font-semibold">Reply as a user</h5>
<p className="text-lg py-1">Given the following conversation, provide an adequate reply</p>
<Messages messages={task.conversation.messages} />
{task.hint && <p className="text-lg py-1">Hint: {task.hint}</p>}
</>
<Textarea name="reply" placeholder="Reply..." ref={inputRef} />
</TwoColumns>
<section className="mb-8 p-4 rounded-lg shadow-lg bg-white flex flex-row justify-items-stretch ">
<div className="grid grid-cols-[min-content_auto] gap-x-2 text-gray-700">
<b>Prompt</b>
<span>{tasks[0].id}</span>
<b>Output</b>
<span>Submit your answer</span>
</div>
<div className="flex justify-center ml-auto">
<Button className="mr-2 bg-indigo-100 text-indigo-700 hover:bg-indigo-200">Skip</Button>
<Button
onClick={() => submitResponse(tasks[0])}
className="bg-indigo-600 text-white shadow-sm hover:bg-indigo-700"
>
Submit
</Button>
</div>
</section>
</div>
);
};
export default UserReply;
+3 -10
View File
@@ -1,16 +1,11 @@
import { useSession } from "next-auth/react";
import Head from "next/head";
import Link from "next/link";
import { Button, Input, Stack } from "@chakra-ui/react";
import { CallToAction } from "../components/CallToAction";
import { Faq } from "../components/Faq";
import { Footer } from "../components/Footer";
import { Header } from "../components/Header";
import { Hero } from "../components/Hero";
import styles from "../styles/Home.module.css";
import { TaskSelection } from "../components/TaskSelection";
export default function Home() {
const { data: session } = useSession();
@@ -46,10 +41,8 @@ export default function Home() {
/>
</Head>
<Header />
<main className="h-3/4 z-0 bg-white flex items-center justify-center">
<Button size="lg" colorScheme="blue" className="drop-shadow">
<Link href="/grading/grade-output">Rate a prompt and output now</Link>
</Button>
<main className="h-3/4 m-20 z-0 bg-white flex flex-col items-center justify-center gap-2">
<TaskSelection />
</main>
<Footer />
</>
@@ -0,0 +1,54 @@
import RankItem from "@/components/RankItem";
import { BarsArrowUpIcon, BarsArrowDownIcon } from "@heroicons/react/24/solid";
import Image from "next/image";
import { HiBarsArrowUp, HiBarsArrowDown } from "react-icons/hi2";
const LeaderBoard = () => {
const PlaceHolderProps = { username: "test_user", score: 10 };
return (
<div className=" p-6 h-full mx-auto bg-slate-100 text-gray-800">
<div className="flex flex-col">
<div className="rounded-lg shadow-lg h-full block bg-white">
<div className="p-8">
<h5 className="text-2xl font-bold">LeaderBoard</h5>
</div>
<div className="flex flex-row justify-between px-6 py-3 font-semibold text-md">
<div className="flex flex-row items-center justify-center space-x-2">
<div>
<p>Rank</p>
</div>
<div className="mt-2 text-slate-400 hover:text-sky-400 hover:cursor-pointer">
<HiBarsArrowDown className="w-6 h-6 text-inherit"></HiBarsArrowDown>
</div>
</div>
<div className="flex flex-row items-center justify-center space-x-2">
<div>
<p>User</p>
</div>
<div className="mt-2 text-slate-400 hover:text-sky-400 hover:cursor-pointer">
<HiBarsArrowDown className="w-6 h-6 text-inherit"></HiBarsArrowDown>
</div>
</div>
<div className="flex flex-row items-center justify-center space-x-2">
<div>
<p>Score</p>
</div>
<div className="mt-2 text-slate-400 hover:text-sky-400 hover:cursor-pointer">
<HiBarsArrowDown className="w-6 h-6 text-inherit"></HiBarsArrowDown>
</div>
</div>
<div className="flex flex-row items-center justify-center space-x-2">
<div>
<p>Medal</p>
</div>
</div>
</div>
{/* leaderboard items */}
<RankItem {...PlaceHolderProps}></RankItem>
</div>
</div>
</div>
);
};
export default LeaderBoard;
+118
View File
@@ -0,0 +1,118 @@
// TODO(#65): Unify and simplify the task paths
import { Textarea } from "@chakra-ui/react";
import { useRef, useState } from "react";
import useSWRMutation from "swr/mutation";
import useSWRImmutable from "swr/immutable";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
const SummarizeStory = () => {
// Use an array of tasks that record the sequence of steps until a task is
// deemed complete.
const [tasks, setTasks] = useState([]);
const inputRef = useRef<HTMLTextAreaElement>(null);
// Fetch the very fist task. We can ignore everything except isLoading
// because the onSuccess handler will update `tasks` when ready.
const { isLoading } = useSWRImmutable("/api/new_task/summarize_story", fetcher, {
onSuccess: (data) => {
console.log(data);
setTasks([data]);
},
});
// Every time we submit an answer to the latest task, let the backend handle
// all the interactions then add the resulting task to the queue. This ends
// when we hit the done task.
const { trigger, isMutating } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
// This is the more efficient way to update a react state array.
setTasks((oldTasks) => [...oldTasks, newTask]);
},
});
// Trigger a mutation that updates the current task. We should probably
// signal somewhere that this interaction is being processed.
const submitResponse = (task: { id: string }) => {
const text = inputRef.current.value.trim();
trigger({
id: task.id,
content: {
update_type: "text_reply_to_post",
text,
},
});
};
/**
* TODO: Make this a nicer loading screen.
*/
if (tasks.length == 0) {
return <div className=" p-6 h-full mx-auto bg-slate-100 text-gray-800">Loading...</div>;
}
return (
<div className=" p-6 h-full mx-auto bg-slate-100 text-gray-800">
{/* Instrunction and Output panels */}
<section className="mb-8 lt-lg:mb-12 ">
<div className="grid lg:gap-x-12 lg:grid-cols-2">
{/* Instruction panel */}
<div className="rounded-lg shadow-lg h-full block bg-white">
<div className="p-6">
<h5 className="text-lg font-semibold">Instruction</h5>
<p className="text-lg py-1">Summarize the following story</p>
<div className="bg-slate-800 p-6 rounded-xl text-white whitespace-pre-wrap">{tasks[0].task.story}</div>
</div>
</div>
{/* Output panel */}
<div className="mt-6 lg:mt-0 rounded-lg shadow-lg h-full block bg-white">
<div className="flex justify-center p-6">
<Textarea name="summary" placeholder="Summary" ref={inputRef} />
</div>
</div>
</div>
</section>
{/* Info & controls */}
<section className="mb-8 p-4 rounded-lg shadow-lg bg-white flex flex-row justify-items-stretch ">
<div className="flex flex-col justify-self-start text-gray-700">
<div>
<span>
<b>Prompt</b>
</span>
<span className="ml-2">{tasks[0].id}</span>
</div>
<div>
<span>
<b>Output</b>
</span>
<span className="ml-2">Submit your answer</span>
</div>
</div>
{/* Skip / Submit controls */}
<div className="flex justify-center ml-auto">
<button
type="button"
className="mr-2 inline-flex items-center rounded-md border border-transparent bg-indigo-100 px-4 py-2 text-sm font-medium text-indigo-700 hover:bg-indigo-200 focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:ring-offset-2"
>
Skip
</button>
<button
type="button"
onClick={() => submitResponse(tasks[0])}
className="inline-flex items-center rounded-md border border-transparent bg-indigo-600 px-4 py-2 text-sm font-medium text-white shadow-sm hover:bg-indigo-700 focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:ring-offset-2"
>
Submit
</button>
</div>
</section>
</div>
);
};
export default SummarizeStory;