Merge remote-tracking branch 'upstream/main'

This commit is contained in:
Andrew Maguire
2023-01-09 17:16:31 +00:00
92 changed files with 2340 additions and 939 deletions
+3
View File
@@ -8,6 +8,9 @@ on:
- ".github/workflows/deploy-docs-site.yaml"
- "docs/**"
pull_request:
paths:
- ".github/workflows/deploy-docs-site.yaml"
- "docs/**"
jobs:
deploy:
+9
View File
@@ -16,3 +16,12 @@ jobs:
with:
python-version: "3.10"
- uses: pre-commit/action@v3.0.0
- name: Post PR comment on failure
if: failure() && github.event_name == 'pull_request'
uses: peter-evans/create-or-update-comment@v2
with:
issue-number: ${{ github.event.pull_request.number }}
body: |
:x: **pre-commit** failed.
Please run `pre-commit run --all-files` locally and commit the changes.
Find more information in the repository's CONTRIBUTING.md
+9
View File
@@ -29,6 +29,15 @@ jobs:
deploy-dev:
needs: [build-backend, build-web, build-bot]
runs-on: ubuntu-latest
env:
WEB_ADMIN_USERS: ${{ secrets.DEV_WEB_ADMIN_USERS }}
WEB_DISCORD_CLIENT_ID: ${{ secrets.DEV_WEB_DISCORD_CLIENT_ID }}
WEB_DISCORD_CLIENT_SECRET: ${{ secrets.DEV_WEB_DISCORD_CLIENT_SECRET }}
WEB_EMAIL_SERVER_HOST: ${{ secrets.DEV_WEB_EMAIL_SERVER_HOST }}
WEB_EMAIL_SERVER_PASSWORD: ${{ secrets.DEV_WEB_EMAIL_SERVER_PASSWORD }}
WEB_EMAIL_SERVER_PORT: ${{ secrets.DEV_WEB_EMAIL_SERVER_PORT }}
WEB_EMAIL_SERVER_USER: ${{ secrets.DEV_WEB_EMAIL_SERVER_USER }}
WEB_NEXTAUTH_SECRET: ${{ secrets.DEV_WEB_NEXTAUTH_SECRET }}
steps:
- name: Checkout
uses: actions/checkout@v2
+4 -2
View File
@@ -96,8 +96,10 @@ The website is built using Next.js and is in the `website` folder.
### Pre-commit
Install `pre-commit` and run `pre-commit install` to install the pre-commit
hooks.
We are using `pre-commit` to enforce code style and formatting.
Install `pre-commit` from [its website](https://pre-commit.com) and run
`pre-commit install` to install the pre-commit hooks.
In case you haven't done this, have already committed, and CI is failing, you
can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
+7
View File
@@ -0,0 +1,7 @@
To test the ansible playbook on localhost run
`ansible-playbook -i test.inventory.ini dev.yaml`.\
In case you're missing the ansible docker depencency install it with `ansible-galaxy collection install community.docker`.\
Point Redis Insights to the Redis database by visiting localhost:8001 in a
browser and select "I already have a database" followed by "Connect to a Redis
Database".\
For host, port and name fill in `oasst-redis`, `6379` and `redis`.
+53 -15
View File
@@ -10,6 +10,39 @@
state: present
driver: bridge
- name: Copy redis.conf to managed node
ansible.builtin.copy:
src: ./redis.conf
dest: ./redis.conf
- name: Set up Redis
community.docker.docker_container:
name: oasst-redis
image: redis
state: started
restart_policy: always
network_mode: oasst
ports:
- 6379:6379
healthcheck:
test: ["CMD-SHELL", "redis-cli ping | grep PONG"]
interval: 2s
timeout: 2s
retries: 10
command: redis-server /usr/local/etc/redis/redis.conf
volumes:
- "./redis.conf:/usr/local/etc/redis/redis.conf"
- name: Set up Redis Insights
community.docker.docker_container:
name: oasst-redis-insights
image: redislabs/redisinsight:latest
state: started
restart_policy: always
network_mode: oasst
ports:
- 8001:8001
- name: Create postgres containers
community.docker.docker_container:
name: "{{ item.name }}"
@@ -32,14 +65,6 @@
- name: oasst-postgres
- name: oasst-postgres-web
- name: Set up maildev
community.docker.docker_container:
name: oasst-maildev
image: maildev/maildev
state: started
restart_policy: always
network_mode: oasst
- name: Run the oasst oasst-backend
community.docker.docker_container:
name: oasst-backend
@@ -51,6 +76,7 @@
network_mode: oasst
env:
POSTGRES_HOST: oasst-postgres
REDIS_HOST: oasst-redis
DEBUG_ALLOW_ANY_API_KEY: "true"
DEBUG_USE_SEED_DATA: "true"
MAX_WORKERS: "1"
@@ -68,15 +94,27 @@
restart_policy: always
network_mode: oasst
env:
FASTAPI_URL: http://oasst-backend:8080
FASTAPI_KEY: "123"
ADMIN_USERS: "{{ lookup('ansible.builtin.env', 'WEB_ADMIN_USERS') }}"
DATABASE_URL: postgres://postgres:postgres@oasst-postgres-web/postgres
NEXTAUTH_SECRET: O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98=
EMAIL_SERVER_HOST: oasst-maildev
EMAIL_SERVER_PORT: "25"
EMAIL_FROM: info@example.com
NEXTAUTH_URL: http://web.dev.open-assistant.io/
DEBUG_LOGIN: "true"
DISCORD_CLIENT_ID:
"{{ lookup('ansible.builtin.env', 'WEB_DISCORD_CLIENT_ID') }}"
DISCORD_CLIENT_SECRET:
"{{ lookup('ansible.builtin.env', 'WEB_DISCORD_CLIENT_SECRET') }}"
EMAIL_FROM: open-assistent@laion.ai
EMAIL_SERVER_HOST:
"{{ lookup('ansible.builtin.env', 'WEB_EMAIL_SERVER_HOST') }}"
EMAIL_SERVER_PASSWORD:
"{{ lookup('ansible.builtin.env', 'WEB_EMAIL_SERVER_PASSWORD') }}"
EMAIL_SERVER_PORT:
"{{ lookup('ansible.builtin.env', 'WEB_EMAIL_SERVER_PORT') }}"
EMAIL_SERVER_USER:
"{{ lookup('ansible.builtin.env', 'WEB_EMAIL_SERVER_USER') }}"
FASTAPI_URL: http://oasst-backend:8080
FASTAPI_KEY: "1234"
NEXTAUTH_SECRET:
"{{ lookup('ansible.builtin.env', 'WEB_NEXTAUTH_SECRET') }}"
NEXTAUTH_URL: http://web.dev.open-assistant.io/
ports:
- 3000:3000
command: bash wait-for-postgres.sh node server.js
+2
View File
@@ -0,0 +1,2 @@
maxmemory 100mb
maxmemory-policy allkeys-lru
+2
View File
@@ -0,0 +1,2 @@
[test]
dev ansible_connection=local
@@ -20,4 +20,4 @@ def upgrade() -> None:
def downgrade() -> None:
op.drop_column("api_client", "frontend_id")
op.drop_column("api_client", "frontend_type")
+11 -6
View File
@@ -14,7 +14,7 @@ from oasst_backend.api.deps import get_dummy_api_client
from oasst_backend.api.v1.api import api_router
from oasst_backend.config import settings
from oasst_backend.database import engine
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from pydantic import BaseModel
@@ -110,7 +110,12 @@ if settings.DEBUG_USE_SEED_DATA:
with Session(engine) as db:
api_client = get_dummy_api_client(db)
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
pr = PromptRepository(db=db, api_client=api_client, user=dummy_user)
ur = UserRepository(db=db, api_client=api_client)
tr = TaskRepository(db=db, api_client=api_client, client_user=dummy_user, user_repository=ur)
pr = PromptRepository(
db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
)
with open(settings.DEBUG_USE_SEED_DATA_PATH) as f:
dummy_messages_raw = json.load(f)
@@ -118,14 +123,14 @@ if settings.DEBUG_USE_SEED_DATA:
dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw]
for msg in dummy_messages:
task = pr.fetch_task_by_frontend_message_id(msg.task_message_id)
task = tr.fetch_task_by_frontend_message_id(msg.task_message_id)
if task and not task.ack:
logger.warning("Deleting unacknowledged seed data task")
db.delete(task)
task = None
if not task:
if msg.parent_message_id is None:
task = pr.store_task(
task = tr.store_task(
protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
)
else:
@@ -144,12 +149,12 @@ if settings.DEBUG_USE_SEED_DATA:
for cmsg in conversation_messages
]
)
task = pr.store_task(
task = tr.store_task(
protocol_schema.AssistantReplyTask(conversation=conversation),
message_tree_id=parent_message.message_tree_id,
parent_message_id=parent_message.id,
)
pr.bind_frontend_message_id(task.id, msg.task_message_id)
tr.bind_frontend_message_id(task.id, msg.task_message_id)
message = pr.store_text_reply(msg.text, msg.task_message_id, msg.user_message_id)
logger.info(
@@ -16,7 +16,7 @@ def get_message_by_frontend_id(
"""
Get a message by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
return utils.prepare_message(message)
@@ -29,7 +29,7 @@ def get_conv_by_frontend_id(
Get a conversation from the tree root and up to the message with given frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
messages = pr.fetch_message_conversation(message)
return utils.prepare_conversation(messages)
@@ -43,7 +43,7 @@ def get_tree_by_frontend_id(
Get all messages belonging to the same message tree.
Message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
tree = pr.fetch_message_tree(message.message_tree_id)
return utils.prepare_tree(tree, message.message_tree_id)
@@ -56,7 +56,7 @@ def get_children_by_frontend_id(
"""
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
messages = pr.fetch_message_children(message.id)
return utils.prepare_message_list(messages)
@@ -70,7 +70,7 @@ def get_descendants_by_frontend_id(
Get a subtree which starts with this message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
descendants = pr.fetch_message_descendants(message)
return utils.prepare_tree(descendants, message.id)
@@ -84,7 +84,7 @@ def get_longest_conv_by_frontend_id(
Get the longest conversation from the tree of the message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
conv = pr.fetch_longest_conversation(message.message_tree_id)
return utils.prepare_conversation(conv)
@@ -98,7 +98,7 @@ def get_max_children_by_frontend_id(
Get message with the most children from the tree of the provided message.
The message is identified by its frontend ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
return utils.prepare_tree([message, *children], message.id)
@@ -29,7 +29,7 @@ def query_frontend_user_messages(
"""
Query frontend user messages.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
messages = pr.query_messages(
username=username,
api_client_id=api_client_id,
@@ -47,6 +47,6 @@ def query_frontend_user_messages(
def mark_frontend_user_messages_deleted(
username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
pr = PromptRepository(db, api_client)
messages = pr.query_messages(username=username, api_client_id=api_client.id)
pr.mark_messages_deleted(messages)
+8 -7
View File
@@ -1,7 +1,8 @@
from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.user_repository import UserRepository
from oasst_shared.schemas.protocol import LeaderboardStats
from sqlmodel import Session
router = APIRouter()
@@ -11,15 +12,15 @@ router = APIRouter()
def get_assistant_leaderboard(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
pr = PromptRepository(db, api_client, None)
return pr.get_user_leaderboard(role="assistant")
) -> LeaderboardStats:
ur = UserRepository(db, api_client)
return ur.get_user_leaderboard(role="assistant")
@router.get("/create/prompter")
def get_prompter_leaderboard(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
pr = PromptRepository(db, api_client, None)
return pr.get_user_leaderboard(role="prompter")
) -> LeaderboardStats:
ur = UserRepository(db, api_client)
return ur.get_user_leaderboard(role="prompter")
+9 -9
View File
@@ -29,7 +29,7 @@ def query_messages(
"""
Query messages.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
messages = pr.query_messages(
username=username,
api_client_id=api_client_id,
@@ -51,7 +51,7 @@ def get_message(
"""
Get a message by its internal ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
return utils.prepare_message(message)
@@ -64,7 +64,7 @@ def get_conv(
Get a conversation from the tree root and up to the message with given internal ID.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
messages = pr.fetch_message_conversation(message_id)
return utils.prepare_conversation(messages)
@@ -76,7 +76,7 @@ def get_tree(
"""
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
tree = pr.fetch_message_tree(message.message_tree_id)
return utils.prepare_tree(tree, message.message_tree_id)
@@ -89,7 +89,7 @@ def get_children(
"""
Get all messages belonging to the same message tree.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
messages = pr.fetch_message_children(message_id)
return utils.prepare_message_list(messages)
@@ -101,7 +101,7 @@ def get_descendants(
"""
Get a subtree which starts with this message.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
descendants = pr.fetch_message_descendants(message)
return utils.prepare_tree(descendants, message.id)
@@ -114,7 +114,7 @@ def get_longest_conv(
"""
Get the longest conversation from the tree of the message.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
conv = pr.fetch_longest_conversation(message.message_tree_id)
return utils.prepare_conversation(conv)
@@ -127,7 +127,7 @@ def get_max_children(
"""
Get message with the most children from the tree of the provided message.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
return utils.prepare_tree([message, *children], message.id)
@@ -137,5 +137,5 @@ def get_max_children(
def mark_message_deleted(
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
pr = PromptRepository(db, api_client)
pr.mark_messages_deleted(message_id)
+1 -1
View File
@@ -13,5 +13,5 @@ def get_message_stats(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
pr = PromptRepository(db, api_client, None)
pr = PromptRepository(db, api_client)
return pr.get_stats()
+10 -10
View File
@@ -7,7 +7,7 @@ from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.api.v1.utils import prepare_conversation
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.prompt_repository import PromptRepository, TaskRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
@@ -190,9 +190,9 @@ def request_task(
api_client = deps.api_auth(api_key, db)
try:
pr = PromptRepository(db, api_client, request.user)
pr = PromptRepository(db, api_client, client_user=request.user)
task, message_tree_id, parent_message_id = generate_task(request, pr)
pr.store_task(task, message_tree_id, parent_message_id, request.collective)
pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective)
except OasstError:
raise
@@ -217,11 +217,11 @@ def tasks_acknowledge(
api_client = deps.api_auth(api_key, db)
try:
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
# here we store the message id in the database for the task
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
pr.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)
pr.task_repository.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)
except OasstError:
raise
@@ -245,8 +245,8 @@ def tasks_acknowledge_failure(
try:
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
api_client = deps.api_auth(api_key, db)
pr = PromptRepository(db, api_client, user=None)
pr.acknowledge_task_failure(task_id)
pr = PromptRepository(db, api_client)
pr.task_repository.acknowledge_task_failure(task_id)
except (KeyError, RuntimeError):
logger.exception("Failed to not acknowledge task.")
raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED)
@@ -265,7 +265,7 @@ def tasks_interaction(
api_client = deps.api_auth(api_key, db)
try:
pr = PromptRepository(db, api_client, user=interaction.user)
pr = PromptRepository(db, api_client, client_user=interaction.user)
match type(interaction):
case protocol_schema.TextReplyToMessage:
@@ -323,6 +323,6 @@ def close_collective_task(
api_key: APIKey = Depends(deps.get_api_key),
):
api_client = deps.api_auth(api_key, db)
pr = PromptRepository(db, api_client, user=None)
pr.close_task(close_task_request.message_id)
tr = TaskRepository(db, api_client)
tr.close_task(close_task_request.message_id)
return protocol_schema.TaskDone()
+1 -1
View File
@@ -25,7 +25,7 @@ def label_text(
try:
logger.info(f"Labeling text {text_labels=}.")
pr = PromptRepository(db, api_client, user=text_labels.user)
pr = PromptRepository(db, api_client, client_user=text_labels.user)
pr.store_text_labels(text_labels)
except Exception:
+2 -2
View File
@@ -29,7 +29,7 @@ def query_user_messages(
"""
Query user messages.
"""
pr = PromptRepository(db, api_client, user=None)
pr = PromptRepository(db, api_client)
messages = pr.query_messages(
user_id=user_id,
api_client_id=api_client_id,
@@ -48,6 +48,6 @@ def query_user_messages(
def mark_user_messages_deleted(
user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
pr = PromptRepository(db, api_client)
messages = pr.query_messages(user_id=user_id)
pr.mark_messages_deleted(messages)
@@ -6,27 +6,56 @@ import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, Index, SQLModel
# The types of States a message tree can have.
class States(str, Enum):
"""States of the Open-Assistant message tree state machine."""
INITIAL_PROMPT_REVIEW = "initial_prompt_review"
"""In this state the message tree consists only of a single inital prompt root node.
Initial prompt labeling tasks will determine if the tree goes into `breeding_phase` or
`aborted_low_grade`."""
class States(Enum):
INITIAL = "initial"
BREEDING_PHASE = "breeding_phase"
"""Assistant & prompter human demonstrations are collected. Concurrently labeling tasks
are handed out to check if the quality of the replies surpasses the minimum acceptable
quality.
When the required number of messages passing the initial labelling-quality check has been
collected the tree will enter `ranking_phase`. If too many poor-quality labelling responses
are received the tree can also enter the `aborted_low_grade` state."""
RANKING_PHASE = "ranking_phase"
"""The tree has been successfully populated with the desired number of messages. Ranking
tasks are now handed out for all nodes with more than one child."""
READY_FOR_SCORING = "ready_for_scoring"
CHILDREN_SCORED = "children_scored"
FINAL = "final"
"""Required ranking responses have been collected and the scoring algorithm can now
compute the aggergated ranking scores that will appear in the dataset."""
READY_FOR_EXPORT = "ready_for_export"
"""The Scoring algorithm computed rankings scores for all childern. The message tree can be
exported as part of an Open-Assistant message tree dataset."""
SCORING_FAILED = "scoring_failed"
"""An exception occured in the scoring algorithm."""
ABORTED_LOW_GRADE = "aborted_low_grade"
"""The system received too many bad reviews and stopped handing out tasks for this message tree."""
HALTED_BY_MODERATOR = "halted_by_moderator"
"""A moderator decided to manually halt the message tree construction process."""
VALID_STATES = (
States.INITIAL,
States.INITIAL_PROMPT_REVIEW,
States.BREEDING_PHASE,
States.RANKING_PHASE,
States.READY_FOR_SCORING,
States.CHILDREN_SCORED,
States.FINAL,
States.READY_FOR_EXPORT,
States.ABORTED_LOW_GRADE,
)
TERMINAL_STATES = (States.READY_FOR_EXPORT, States.ABORTED_LOW_GRADE, States.SCORING_FAILED, States.HALTED_BY_MODERATOR)
class MessageTreeState(SQLModel, table=True):
__tablename__ = "message_tree_state"
+1 -1
View File
@@ -35,4 +35,4 @@ class Task(SQLModel, table=True):
@property
def expired(self) -> bool:
return self.expiry_date is not None and datetime.utcnow() < self.expiry_date
return self.expiry_date is not None and datetime.utcnow() > self.expiry_date
+58 -268
View File
@@ -8,98 +8,39 @@ from uuid import UUID, uuid4
import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User
from oasst_backend.models import ApiClient, Message, MessageReaction, TextLabels, User
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id
from oasst_backend.user_repository import UserRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import LeaderboardStats, SystemStats
from oasst_shared.schemas.protocol import SystemStats
from sqlalchemy import update
from sqlmodel import Session, func
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
class PromptRepository:
def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]):
def __init__(
self,
db: Session,
api_client: ApiClient,
client_user: Optional[protocol_schema.User] = None,
user_repository: Optional[UserRepository] = None,
task_repository: Optional[TaskRepository] = None,
):
self.db = db
self.api_client = api_client
self.user = self.lookup_user(user)
self.user_repository = user_repository or UserRepository(db, api_client)
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
self.user_id = self.user.id if self.user else None
self.task_repository = task_repository or TaskRepository(
db, api_client, client_user, user_repository=self.user_repository
)
self.journal = JournalWriter(db, api_client, self.user)
def lookup_user(self, client_user: protocol_schema.User) -> Optional[User]:
if not client_user:
return None
user: User = (
self.db.query(User)
.filter(
User.api_client_id == self.api_client.id,
User.username == client_user.id,
User.auth_method == client_user.auth_method,
)
.first()
)
if user is None:
# user is unknown, create new record
user = User(
username=client_user.id,
display_name=client_user.display_name,
api_client_id=self.api_client.id,
auth_method=client_user.auth_method,
)
self.db.add(user)
self.db.commit()
self.db.refresh(user)
elif client_user.display_name and client_user.display_name != user.display_name:
# we found the user but the display name changed
user.display_name = client_user.display_name
self.db.add(user)
self.db.commit()
return user
def validate_frontend_message_id(self, message_id: str) -> None:
# TODO: Should it be replaced with fastapi/pydantic validation?
if not isinstance(message_id, str):
raise OasstError(
f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID
)
if not message_id:
raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
self.validate_frontend_message_id(frontend_message_id)
# find task
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
if task is None:
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if task.done or task.ack is not None:
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
task.frontend_message_id = frontend_message_id
task.ack = True
# ToDo: check race-condition, transaction
self.db.add(task)
self.db.commit()
def acknowledge_task_failure(self, task_id):
# find task
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
if task is None:
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if task.done or task.ack is not None:
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
task.ack = False
# ToDo: check race-condition, transaction
self.db.add(task)
self.db.commit()
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
self.validate_frontend_message_id(frontend_message_id)
validate_frontend_message_id(frontend_message_id)
message: Message = (
self.db.query(Message)
.filter(Message.api_client_id == self.api_client.id, Message.frontend_message_id == frontend_message_id)
@@ -113,20 +54,48 @@ class PromptRepository:
)
return message
def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
self.validate_frontend_message_id(message_id)
task = (
self.db.query(Task)
.filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id)
.one_or_none()
def insert_message(
self,
*,
message_id: UUID,
frontend_message_id: str,
parent_id: UUID,
message_tree_id: UUID,
task_id: UUID,
role: str,
payload: db_payload.MessagePayload,
payload_type: str = None,
depth: int = 0,
) -> Message:
if payload_type is None:
if payload is None:
payload_type = "null"
else:
payload_type = type(payload).__name__
message = Message(
id=message_id,
parent_id=parent_id,
message_tree_id=message_tree_id,
task_id=task_id,
user_id=self.user_id,
role=role,
frontend_message_id=frontend_message_id,
api_client_id=self.api_client.id,
payload_type=payload_type,
payload=PayloadContainer(payload=payload),
depth=depth,
)
return task
self.db.add(message)
self.db.commit()
self.db.refresh(message)
return message
def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message:
self.validate_frontend_message_id(frontend_message_id)
self.validate_frontend_message_id(user_frontend_message_id)
validate_frontend_message_id(frontend_message_id)
validate_frontend_message_id(user_frontend_message_id)
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id)
if task is None:
raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
@@ -174,7 +143,7 @@ class PromptRepository:
def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction:
message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True)
task = self.fetch_task_by_frontend_message_id(rating.message_id)
task = self.task_repository.fetch_task_by_frontend_message_id(rating.message_id)
task_payload: db_payload.RateSummaryPayload = task.payload.payload
if type(task_payload) != db_payload.RateSummaryPayload:
raise OasstError(
@@ -201,7 +170,7 @@ class PromptRepository:
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> MessageReaction:
# fetch task
task = self.fetch_task_by_frontend_message_id(ranking.message_id)
task = self.task_repository.fetch_task_by_frontend_message_id(ranking.message_id)
if not task.collective:
task.done = True
self.db.add(task)
@@ -255,142 +224,6 @@ class PromptRepository:
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
)
def store_task(
self,
task: protocol_schema.Task,
message_tree_id: UUID = None,
parent_message_id: UUID = None,
collective: bool = False,
) -> Task:
payload: db_payload.TaskPayload
match type(task):
case protocol_schema.SummarizeStoryTask:
payload = db_payload.SummarizationStoryPayload(story=task.story)
case protocol_schema.RateSummaryTask:
payload = db_payload.RateSummaryPayload(
full_text=task.full_text, summary=task.summary, scale=task.scale
)
case protocol_schema.InitialPromptTask:
payload = db_payload.InitialPromptPayload(hint=task.hint)
case protocol_schema.PrompterReplyTask:
payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
case protocol_schema.AssistantReplyTask:
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
case protocol_schema.RankInitialPromptsTask:
payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts)
case protocol_schema.RankPrompterRepliesTask:
payload = db_payload.RankPrompterRepliesPayload(
type=task.type, conversation=task.conversation, replies=task.replies
)
case protocol_schema.RankAssistantRepliesTask:
payload = db_payload.RankAssistantRepliesPayload(
type=task.type, conversation=task.conversation, replies=task.replies
)
case protocol_schema.LabelInitialPromptTask:
payload = db_payload.LabelInitialPromptPayload(
type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
)
case protocol_schema.LabelPrompterReplyTask:
payload = db_payload.LabelPrompterReplyPayload(
type=task.type,
message_id=task.message_id,
conversation=task.conversation,
reply=task.reply,
valid_labels=task.valid_labels,
)
case protocol_schema.LabelAssistantReplyTask:
payload = db_payload.LabelAssistantReplyPayload(
type=task.type,
message_id=task.message_id,
conversation=task.conversation,
reply=task.reply,
valid_labels=task.valid_labels,
)
case _:
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
task = self.insert_task(
payload=payload,
id=task.id,
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
)
assert task.id == task.id
return task
def insert_task(
self,
payload: db_payload.TaskPayload,
id: UUID = None,
message_tree_id: UUID = None,
parent_message_id: UUID = None,
collective: bool = False,
) -> Task:
c = PayloadContainer(payload=payload)
task = Task(
id=id,
user_id=self.user_id,
payload_type=type(payload).__name__,
payload=c,
api_client_id=self.api_client.id,
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
)
self.db.add(task)
self.db.commit()
self.db.refresh(task)
return task
def insert_message(
self,
*,
message_id: UUID,
frontend_message_id: str,
parent_id: UUID,
message_tree_id: UUID,
task_id: UUID,
role: str,
payload: db_payload.MessagePayload,
payload_type: str = None,
depth: int = 0,
) -> Message:
if payload_type is None:
if payload is None:
payload_type = "null"
else:
payload_type = type(payload).__name__
message = Message(
id=message_id,
parent_id=parent_id,
message_tree_id=message_tree_id,
task_id=task_id,
user_id=self.user_id,
role=role,
frontend_message_id=frontend_message_id,
api_client_id=self.api_client.id,
payload_type=payload_type,
payload=PayloadContainer(payload=payload),
depth=depth,
)
self.db.add(message)
self.db.commit()
self.db.refresh(message)
return message
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
if self.user_id is None:
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
@@ -515,28 +348,6 @@ class PromptRepository:
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
return message
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
"""
Mark task as done. No further messages will be accepted for this task.
"""
self.validate_frontend_message_id(frontend_message_id)
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
if not task:
raise OasstError(
f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND
)
if task.expired:
raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
if not allow_personal_tasks and not task.collective:
raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
if task.done:
raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)
task.done = True
self.db.add(task)
self.db.commit()
@staticmethod
def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]:
"""
@@ -728,24 +539,3 @@ class PromptRepository:
deleted=result.get(True, 0),
message_trees=result.get(None, 0),
)
def get_user_leaderboard(self, role: str) -> LeaderboardStats:
"""
Get leaderboard stats for Messages created,
separate leaderboard for prompts & assistants
"""
query = (
self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id))
.join(User, User.id == Message.user_id, isouter=True)
.filter(Message.deleted is not True, Message.role == role)
.group_by(Message.user_id, User.username, User.display_name)
.order_by(func.count(Message.user_id).desc())
)
result = [
{"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]}
for i, j in enumerate(query.all(), start=1)
]
return LeaderboardStats(leaderboard=result)
+199
View File
@@ -0,0 +1,199 @@
from typing import Optional
from uuid import UUID
import oasst_backend.models.db_payload as db_payload
from oasst_backend.models import ApiClient, Task
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_backend.user_repository import UserRepository
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_404_NOT_FOUND
def validate_frontend_message_id(message_id: str) -> None:
# TODO: Should it be replaced with fastapi/pydantic validation?
if not isinstance(message_id, str):
raise OasstError(
f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID
)
if not message_id:
raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
class TaskRepository:
def __init__(
self,
db: Session,
api_client: ApiClient,
client_user: Optional[protocol_schema.User],
user_repository: UserRepository,
):
self.db = db
self.api_client = api_client
self.user_repository = user_repository
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
self.user_id = self.user.id if self.user else None
def store_task(
self,
task: protocol_schema.Task,
message_tree_id: UUID = None,
parent_message_id: UUID = None,
collective: bool = False,
) -> Task:
payload: db_payload.TaskPayload
match type(task):
case protocol_schema.SummarizeStoryTask:
payload = db_payload.SummarizationStoryPayload(story=task.story)
case protocol_schema.RateSummaryTask:
payload = db_payload.RateSummaryPayload(
full_text=task.full_text, summary=task.summary, scale=task.scale
)
case protocol_schema.InitialPromptTask:
payload = db_payload.InitialPromptPayload(hint=task.hint)
case protocol_schema.PrompterReplyTask:
payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
case protocol_schema.AssistantReplyTask:
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
case protocol_schema.RankInitialPromptsTask:
payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts)
case protocol_schema.RankPrompterRepliesTask:
payload = db_payload.RankPrompterRepliesPayload(
type=task.type, conversation=task.conversation, replies=task.replies
)
case protocol_schema.RankAssistantRepliesTask:
payload = db_payload.RankAssistantRepliesPayload(
type=task.type, conversation=task.conversation, replies=task.replies
)
case protocol_schema.LabelInitialPromptTask:
payload = db_payload.LabelInitialPromptPayload(
type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
)
case protocol_schema.LabelPrompterReplyTask:
payload = db_payload.LabelPrompterReplyPayload(
type=task.type,
message_id=task.message_id,
conversation=task.conversation,
reply=task.reply,
valid_labels=task.valid_labels,
)
case protocol_schema.LabelAssistantReplyTask:
payload = db_payload.LabelAssistantReplyPayload(
type=task.type,
message_id=task.message_id,
conversation=task.conversation,
reply=task.reply,
valid_labels=task.valid_labels,
)
case _:
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
task = self.insert_task(
payload=payload,
id=task.id,
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
)
assert task.id == task.id
return task
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
validate_frontend_message_id(frontend_message_id)
# find task
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
if task is None:
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if task.done or task.ack is not None:
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
task.frontend_message_id = frontend_message_id
task.ack = True
# ToDo: check race-condition, transaction
self.db.add(task)
self.db.commit()
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
"""
Mark task as done. No further messages will be accepted for this task.
"""
validate_frontend_message_id(frontend_message_id)
task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id)
if not task:
raise OasstError(
f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND
)
if task.expired:
raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
if not allow_personal_tasks and not task.collective:
raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
if task.done:
raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)
task.done = True
self.db.add(task)
self.db.commit()
def acknowledge_task_failure(self, task_id):
# find task
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
if task is None:
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if task.done or task.ack is not None:
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
task.ack = False
# ToDo: check race-condition, transaction
self.db.add(task)
self.db.commit()
def insert_task(
self,
payload: db_payload.TaskPayload,
id: UUID = None,
message_tree_id: UUID = None,
parent_message_id: UUID = None,
collective: bool = False,
) -> Task:
c = PayloadContainer(payload=payload)
task = Task(
id=id,
user_id=self.user_id,
payload_type=type(payload).__name__,
payload=c,
api_client_id=self.api_client.id,
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
)
self.db.add(task)
self.db.commit()
self.db.refresh(task)
return task
def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
validate_frontend_message_id(message_id)
task = (
self.db.query(Task)
.filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id)
.one_or_none()
)
return task
+64
View File
@@ -0,0 +1,64 @@
from typing import Optional
from oasst_backend.models import ApiClient, Message, User
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import LeaderboardStats
from sqlmodel import Session, func
class UserRepository:
def __init__(self, db: Session, api_client: ApiClient):
self.db = db
self.api_client = api_client
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
if not client_user:
return None
user: User = (
self.db.query(User)
.filter(
User.api_client_id == self.api_client.id,
User.username == client_user.id,
User.auth_method == client_user.auth_method,
)
.first()
)
if user is None:
if create_missing:
# user is unknown, create new record
user = User(
username=client_user.id,
display_name=client_user.display_name,
api_client_id=self.api_client.id,
auth_method=client_user.auth_method,
)
self.db.add(user)
self.db.commit()
self.db.refresh(user)
elif client_user.display_name and client_user.display_name != user.display_name:
# we found the user but the display name changed
user.display_name = client_user.display_name
self.db.add(user)
self.db.commit()
return user
def get_user_leaderboard(self, role: str) -> LeaderboardStats:
"""
Get leaderboard stats for Messages created,
separate leaderboard for prompts & assistants
"""
query = (
self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id))
.join(User, User.id == Message.user_id, isouter=True)
.filter(Message.deleted is not True, Message.role == role)
.group_by(Message.user_id, User.username, User.display_name)
.order_by(func.count(Message.user_id).desc())
)
result = [
{"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]}
for i, j in enumerate(query.all(), start=1)
]
return LeaderboardStats(leaderboard=result)
+69 -1
View File
@@ -1,7 +1,63 @@
# General
# Research
This page lists research papers that are relevant to the project.
## Table of Contents
- Reinforcement Learning from Human Feedback
- Generating Text From Language Models
- Automatically Generating Instruction Data for Training
- Uncertainty Estimation of Language Model Outputs
## Reinforcement Learning from Human Feedback <a name="reinforcement-learning-from-human-feedback"></a>
Reinforcement Learning from Human Feedback (RLHF) is a method for fine-tuning a
generative language models based on a reward model that is learned from human
preference data. This method facilitates the learning of instruction-tuned
models, among other things.
### Learning to summarize from human feedback [[ArXiv](https://arxiv.org/pdf/2009.01325.pdf)], [[Github](https://github.com/openai/summarize-from-feedback)]
> In this work, we show that it is possible to significantly improve summary
> quality by training a model to optimize for human preferences. We collect a
> large, high-quality dataset of human comparisons between summaries, train a
> model to predict the human-preferred summary, and use that model as a reward
> function to fine-tune a summarization policy using reinforcement learning.
### Training language models to follow instructions with human feedback [[ArXiv](https://arxiv.org/pdf/2203.02155.pdf)]
> Starting with a set of labeler-written prompts and prompts submitted through
> the OpenAI API, we collect a dataset of labeler demonstrations of the desired
> model behavior, which we use to fine-tune GPT-3 using supervised learning. We
> then collect a dataset of rankings of model outputs, which we use to further
> fine-tune this supervised model using reinforcement learning from human
> feedback.
### Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback [[ArXiv](https://arxiv.org/pdf/2204.05862.pdf)]
> We apply preference modeling and reinforcement learning from human feedback
> (RLHF) to finetune language models to act as helpful and harmless assistants.
> We find this alignment training improves performance on almost all NLP
> evaluations, and is fully compatible with training for specialized skills such
> as python coding and summarization.
## Generating Text From Language Models
A language model generates output text token by token, autoregressively. The
large search space of this task requires some method of narrowing down the set
of tokens to be considered in each step. This method, in turn, has a big impact
on the quality of the resulting text.
### RANKGEN: Improving Text Generation with Large Ranking Models [[ArXiv](https://arxiv.org/pdf/2205.09726.pdf)], [[Github](https://github.com/martiansideofthemoon/rankgen)]
> Given an input sequence (or prefix), modern language models often assign high
> probabilities to output sequences that are repetitive, incoherent, or
> irrelevant to the prefix; as such, model-generated text also contains such
> artifacts. To address these issues we present RankGen, a 1.2B parameter
> encoder model for English that scores model generations given a prefix.
> RankGen can be flexibly incorporated as a scoring function in beam search and
> used to decode from any pretrained language model.
## Automatically Generating Instruction Data for Training
This line of work is about significantly reducing the need for manually
@@ -32,3 +88,15 @@ models.
> rivals the effectiveness of training on open-source manually-curated datasets,
> surpassing the performance of models such as T0++ and Tk-Instruct across
> various benchmarks.
## Uncertainty Estimation of Language Model Outputs
### Teaching models to express their uncertainty in words [[Arxiv](https://arxiv.org/pdf/2205.14334.pdf)]
> We show that a GPT-3 model can learn to express uncertainty about its own
> answers in natural language -- without use of model logits. When given a
> question, the model generates both an answer and a level of confidence (e.g.
> "90% confidence" or "high confidence"). These levels map to probabilities that
> are well calibrated. The model also remains moderately calibrated under
> distribution shift, and is sensitive to uncertainty in its own answers, rather
> than imitating human examples.
+46 -2
View File
@@ -19,7 +19,7 @@
"""
from dataclasses import dataclass
from typing import Optional, Union
from typing import Dict, List, Optional, Union
import numpy as np
import torch
@@ -35,7 +35,7 @@ class RankGenCollator:
max_length: Optional[int] = None
max_examples: Optional[int] = None
def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]:
def __call__(self, batch: List[Dict[str, str]]) -> Dict[str, torch.Tensor]:
prefixes = []
better_answers = []
worse_answers = []
@@ -193,3 +193,47 @@ class HFSummary(Dataset):
valid_idx = np.random.choice(len(rows), self.max_comparison_per_sample)
# optimize the format later
return context + self.postfix_prompt, [r for idx, r in enumerate(rows) if idx in valid_idx]
class HFDataset(Dataset):
"""
This is a base huggingface dataset which written to support the
simplest pos-neg pair format
we should do something like this for supervised datasets
"""
def __init__(
self, dataset_name, question_field, pos_answer_field, neg_answer_field, subset=None, split=None
) -> None:
super().__init__()
dataset = load_dataset(dataset_name, subset)
if split is not None:
dataset = dataset[split]
self.questions = {}
self.index2question = {}
for row in dataset:
question = row[question_field].strip()
pos = row[pos_answer_field]
neg = row[neg_answer_field]
if question not in self.index2question:
self.index2question[len(self.index2question)] = question
if question not in self.questions:
self.questions[question] = []
self.questions[question].append((pos.strip(), neg.strip()))
def __len__(self):
return len(self.index2question)
def __getitem__(self, index):
question = self.index2question[index]
rows = self.questions[question]
# optimize the format later
return question, rows
class GPTJSynthetic(HFDataset):
def __init__(self) -> None:
super().__init__("Dahoas/synthetic-instruct-gptj-pairwise", "prompt", "chosen", "rejected", None, "train")
+11 -5
View File
@@ -1,5 +1,5 @@
from experimental_dataset import DataCollatorForSummaryScore, HFSummaryQuality
from rank_datasets import DataCollatorForPairRank, HFSummary, WebGPT
from rank_datasets import DataCollatorForPairRank, GPTJSynthetic, HFSummary, WebGPT
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
@@ -25,7 +25,7 @@ def test_webgpt():
print(batch["input_ids"].shape)
def test_hf_quality():
def test_hf_summary_quality():
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
collate_fn = DataCollatorForSummaryScore(tokenizer, max_length=200)
@@ -35,6 +35,12 @@ def test_hf_quality():
print(batch["input_ids"].shape)
if __name__ == "__main__":
test_hf_quality()
# test_webgpt()
def test_gptj_dataset():
dataset = GPTJSynthetic()
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
collate_fn = DataCollatorForPairRank(tokenizer, max_length=1024)
print(len(dataset))
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=32)
for batch in dataloader:
batch["input_ids"].shape
+12 -55
View File
@@ -1,29 +1,23 @@
import os
from argparse import ArgumentParser
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import evaluate
import numpy as np
import torch
from models import RankGenModel
from rank_datasets import DataCollatorForPairRank, HFSummary, RankGenCollator, WebGPT
from rank_datasets import DataCollatorForPairRank, RankGenCollator
from torch import nn
from torch.utils.data import ConcatDataset, Dataset
from transformers import (
AdamW,
AutoModelForSequenceClassification,
DataCollator,
EvalPrediction,
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
TrainerCallback,
TrainingArguments,
get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
from utils import argument_parsing, freeze_top_n_layers, get_tokenizer, train_val_dataset
from utils import argument_parsing, freeze_top_n_layers, get_datasets, get_tokenizer
os.environ["WANDB_PROJECT"] = "reward-model"
@@ -32,11 +26,6 @@ parser = ArgumentParser()
parser.add_argument("config", type=str)
@dataclass
class CustomTrainingArguments(TrainingArguments):
loss_function: str = "rank"
def compute_metrics(eval_pred):
predictions, _ = eval_pred
predictions = np.argmax(predictions, axis=1)
@@ -60,31 +49,12 @@ class RankTrainer(Trainer):
model: Union[PreTrainedModel, nn.Module] = None,
model_name: str = None,
args: Optional[TrainingArguments] = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Dataset] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
model_init: Callable[[], PreTrainedModel] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
loss_function: str = "rank",
**kwargs,
):
super().__init__(
model,
args,
data_collator,
train_dataset,
eval_dataset,
tokenizer,
model_init,
compute_metrics,
callbacks,
optimizers,
preprocess_logits_for_metrics,
)
self.loss_fct = RankLoss() if args.loss_function == "rank" else nn.CrossEntropyLoss()
self.loss_function = args.loss_function
super().__init__(model, args, **kwargs)
self.loss_fct = RankLoss() if loss_function == "rank" else nn.CrossEntropyLoss()
self.loss_function = loss_function
self.model_name = model_name
def compute_loss(self, model, inputs, return_outputs=False):
@@ -163,11 +133,10 @@ if __name__ == "__main__":
params = sum([np.prod(p.size()) for p in model_parameters])
print("Number of trainable : {}M".format(int(params / 1e6)))
args = CustomTrainingArguments(
args = TrainingArguments(
output_dir=f"{model_name}-finetuned",
num_train_epochs=training_conf["num_train_epochs"],
warmup_steps=500,
loss_function=training_conf["loss"],
learning_rate=training_conf["learning_rate"],
# half_precision_backend="apex",
fp16=training_conf["fp16"],
@@ -184,22 +153,9 @@ if __name__ == "__main__":
save_steps=1000,
report_to="wandb",
)
train_datasets, evals = [], {}
if "webgpt" in training_conf["datasets"]:
web_dataset = WebGPT()
train, eval = train_val_dataset(web_dataset)
train_datasets.append(train)
evals["webgpt"] = eval
if "hfsummary" in training_conf["datasets"]:
sum_train = HFSummary(split="train")
train_datasets.append(sum_train)
sum_eval = HFSummary(split="valid1")
assert len(sum_eval) > 0
evals["hfsummary"] = sum_eval
train = ConcatDataset(train_datasets)
tokenizer = get_tokenizer(training_conf["tokenizer_name"])
train, evals = get_datasets(training_conf["datasets"])
if "rankgen" in model_name:
collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"])
else:
@@ -224,8 +180,9 @@ if __name__ == "__main__":
model=model,
model_name=model_name,
args=args,
loss_function=training_conf["loss"],
train_dataset=train,
eval_dataset=eval,
eval_dataset=evals,
data_collator=collate_fn,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
+29 -1
View File
@@ -1,11 +1,13 @@
import re
from typing import AnyStr, List
import yaml
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from transformers import AutoTokenizer, T5Tokenizer
re_reference_remove = re.compile(r"\[([0-9])+\]|\[([0-9])+,([0-9])+\]")
# @agoryuno contributed this
re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]")
def webgpt_return_format(row):
@@ -97,6 +99,32 @@ def argument_parsing(parser):
return params
def get_datasets(dataset_list: List[AnyStr]):
from rank_datasets import GPTJSynthetic, HFSummary, WebGPT
from torch.utils.data import ConcatDataset
train_datasets, evals = [], {}
for dataset_name in dataset_list:
if "webgpt" == dataset_name:
web_dataset = WebGPT()
train, eval = train_val_dataset(web_dataset, 0.2)
train_datasets.append(train)
evals["webgpt"] = eval
elif "hfsummary" == dataset_name:
sum_train = HFSummary(split="train")
train_datasets.append(sum_train)
sum_eval = HFSummary(split="valid1")
assert len(sum_eval) > 0
evals["hfsummary"] = sum_eval
elif "gptsynthetic" == dataset_name:
dataset = GPTJSynthetic()
train, eval = train_val_dataset(dataset, 0.1)
train_datasets.append(train)
evals["gptsynthetic"] = eval
train = ConcatDataset(train_datasets)
return train, evals
if __name__ == "__main__":
from transformers import AutoModelForSequenceClassification
+17 -24
View File
@@ -272,36 +272,29 @@ class TextLabel(str, enum.Enum):
obj.help_text = help_text
return obj
spam = "spam"
spam = "spam", "Seems to be intentionally low-quality or irrelevant"
fails_task = "fails_task", "Fails to follow the correct instruction / task"
not_appropriate = "not_appropriate", "Inappropriate for customer assistant"
violence = "violence", "Encourages or fails to discourage violence/abuse/terrorism/self-harm"
harmful = (
"harmful",
"Harmful content",
"The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the question about encouraging violence/abuse/terrorism/self-harm.",
excessive_harm = (
"excessive_harm",
"Content likely to cause excessive harm not justifiable in the context",
"Harm refers to physical or mental damage or injury to someone or something. Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.",
)
sexual_content = "sexual_content", "Contains sexual content"
toxicity = "toxicity"
toxicity = "toxicity", "Contains rude, abusive, profane or insulting content"
moral_judgement = "moral_judgement", "Expresses moral judgement"
political_content = "political_content"
humor = "humor"
sarcasm = "sarcasm"
hate_speech = "hate_speech"
profanity = "profanity"
ad_hominem = "ad_hominem"
insult = "insult"
threat = "threat"
aggressive = "aggressive"
misleading = "misleading"
helpful = "helpful"
formal = "formal"
cringe = "cringe"
creative = "creative"
beautiful = "beautiful"
informative = "informative"
based = "based"
slang = "slang"
political_content = "political_content", "Expresses political views"
humor = "humor", "Contains humorous content including sarcasm"
hate_speech = (
"hate_speech",
"Content is abusive or threatening and expresses prejudice against a protected characteristic",
"Prejudice refers to preconceived views not based on reason. Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.",
)
threat = "threat", "Contains a threat against a person or persons"
misleading = "misleading", "Contains text which is incorrect or misleading"
helpful = "helpful", "Completes the task to a high standard"
creative = "creative", "Expresses creativity in responding to the task"
class TextLabels(Interaction):
+2 -1
View File
@@ -8,7 +8,8 @@
"rules": {
"unused-imports/no-unused-imports": "warn",
"simple-import-sort/imports": "warn",
"simple-import-sort/exports": "warn"
"simple-import-sort/exports": "warn",
"eqeqeq": "warn"
},
"plugins": ["simple-import-sort", "unused-imports"]
}
+2
View File
@@ -39,5 +39,7 @@ next-env.d.ts
*.swp
# cypress
/cypress/screenshots
/cypress/videos
/cypress-visual-screenshots/diff
/cypress-visual-screenshots/comparison
+28
View File
@@ -153,6 +153,34 @@ When writing code for the website, we have a few best practices:
1. Define functional React components (with types for all properties when
feasible).
### Developing New Features
When working on new features or making significant changes that can't be done
within a single Pull Request, we ask that you make use of Feature Flags.
We've set up
[`react-feature-flags`](https://www.npmjs.com/package/react-feature-flags) to
make this easier. To get started:
1. Add a new flag entry to `website/src/flags.ts`. We have an example flag you
can copy as an example. Be sure to `isActive` to true when testing your
features but false when submitting your PR.
1. Use your flag wherever you add a new UI element. This can be done with:
```js
import { Flags } from "react-feature-flags";
...
<Flags authorizedFlags={["yourFlagName"]}>
<YourNewComponent />
</Flags>
```
You can see an example of how this works by checking `website/src/components/Header/Headers.tsx` where we use `flagTest`.
1. Once you've finished building out the feature and it is ready for everyone
to use, it's safe to remove the `Flag` wrappers around your component and
the entry in `flags.ts`.
### URL Paths
To use stable and consistent URL paths, we recommend the following strategy for
+375 -10
View File
@@ -29,6 +29,7 @@
"eslint-config-next": "13.0.6",
"eslint-plugin-simple-import-sort": "^8.0.0",
"focus-visible": "^5.2.0",
"formik": "^2.2.9",
"framer-motion": "^6.5.1",
"install": "^0.13.0",
"next": "13.0.6",
@@ -38,6 +39,7 @@
"postcss-focus-visible": "^7.1.0",
"react": "18.2.0",
"react-dom": "18.2.0",
"react-feature-flags": "^1.0.0",
"react-icons": "^4.7.1",
"swr": "^2.0.0",
"tailwindcss": "^3.2.4",
@@ -56,7 +58,7 @@
"@storybook/manager-webpack5": "^6.5.15",
"@storybook/react": "^6.5.15",
"@storybook/testing-library": "^0.0.13",
"@types/node": "18.11.17",
"@types/node": "^18.11.17",
"@types/react": "18.0.26",
"@typescript-eslint/eslint-plugin": "^5.47.1",
"babel-loader": "^8.3.0",
@@ -66,7 +68,8 @@
"eslint-plugin-unused-imports": "^2.0.0",
"prettier": "2.8.1",
"prisma": "^4.7.1",
"typescript": "4.9.4"
"ts-node": "^10.9.1",
"typescript": "^4.9.4"
}
},
"node_modules/@ampproject/remapping": {
@@ -3122,6 +3125,28 @@
"node": ">=0.1.90"
}
},
"node_modules/@cspotcode/source-map-support": {
"version": "0.8.1",
"resolved": "https://registry.npmjs.org/@cspotcode/source-map-support/-/source-map-support-0.8.1.tgz",
"integrity": "sha512-IchNf6dN4tHoMFIn/7OE8LWZ19Y6q/67Bmf6vnGREv8RSbBVb9LPJxEcnwrcwX6ixSvaiGoomAUvu4YSxXrVgw==",
"devOptional": true,
"dependencies": {
"@jridgewell/trace-mapping": "0.3.9"
},
"engines": {
"node": ">=12"
}
},
"node_modules/@cspotcode/source-map-support/node_modules/@jridgewell/trace-mapping": {
"version": "0.3.9",
"resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.9.tgz",
"integrity": "sha512-3Belt6tdc8bPgAtbcmdtNJlirVoTmEb5e2gC94PnkwEW9jI6CAHUeoG85tjWP5WquqfavoMtMwiG4P926ZKKuQ==",
"devOptional": true,
"dependencies": {
"@jridgewell/resolve-uri": "^3.0.3",
"@jridgewell/sourcemap-codec": "^1.4.10"
}
},
"node_modules/@cypress/request": {
"version": "2.88.10",
"resolved": "https://registry.npmjs.org/@cypress/request/-/request-2.88.10.tgz",
@@ -10836,6 +10861,30 @@
"@testing-library/dom": ">=7.21.4"
}
},
"node_modules/@tsconfig/node10": {
"version": "1.0.9",
"resolved": "https://registry.npmjs.org/@tsconfig/node10/-/node10-1.0.9.tgz",
"integrity": "sha512-jNsYVVxU8v5g43Erja32laIDHXeoNvFEpX33OK4d6hljo3jDhCBDhx5dhCCTMWUojscpAagGiRkBKxpdl9fxqA==",
"devOptional": true
},
"node_modules/@tsconfig/node12": {
"version": "1.0.11",
"resolved": "https://registry.npmjs.org/@tsconfig/node12/-/node12-1.0.11.tgz",
"integrity": "sha512-cqefuRsh12pWyGsIoBKJA9luFu3mRxCA+ORZvA4ktLSzIuCUtWVxGIuXigEwO5/ywWFMZ2QEGKWvkZG1zDMTag==",
"devOptional": true
},
"node_modules/@tsconfig/node14": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/@tsconfig/node14/-/node14-1.0.3.tgz",
"integrity": "sha512-ysT8mhdixWK6Hw3i1V2AeRqZ5WfXg1G43mqoYlM2nc6388Fq5jcXyr5mRsqViLx/GJYdoL0bfXD8nmF+Zn/Iow==",
"devOptional": true
},
"node_modules/@tsconfig/node16": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/@tsconfig/node16/-/node16-1.0.3.tgz",
"integrity": "sha512-yOlFc+7UtL/89t2ZhjPvvB/DeAr3r+Dq58IgzsFkOAvVC6NMJXmCGjbptdXdR9qsX7pKcTL+s87FtYREi2dEEQ==",
"devOptional": true
},
"node_modules/@types/aria-query": {
"version": "5.0.1",
"resolved": "https://registry.npmjs.org/@types/aria-query/-/aria-query-5.0.1.tgz",
@@ -10975,7 +11024,7 @@
"version": "18.11.17",
"resolved": "https://registry.npmjs.org/@types/node/-/node-18.11.17.tgz",
"integrity": "sha512-HJSUJmni4BeDHhfzn6nF0sVmd1SMezP7/4F0Lq+aXzmp2xm9O7WXrUtHW/CHlYVtZUbByEvWidHqRtcJXGF2Ng==",
"dev": true
"devOptional": true
},
"node_modules/@types/node-fetch": {
"version": "2.6.2",
@@ -12013,7 +12062,7 @@
"version": "4.1.3",
"resolved": "https://registry.npmjs.org/arg/-/arg-4.1.3.tgz",
"integrity": "sha512-58S9QDqG0Xx27YwPSt9fJxivjYl432YCwfDMfZ+71RAqUrZef7LrKQZ3LHLOwCS4FLNBplP533Zx895SeOCHvA==",
"dev": true
"devOptional": true
},
"node_modules/argparse": {
"version": "1.0.10",
@@ -14691,6 +14740,12 @@
"sha.js": "^2.4.8"
}
},
"node_modules/create-require": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/create-require/-/create-require-1.1.1.tgz",
"integrity": "sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==",
"devOptional": true
},
"node_modules/cross-spawn": {
"version": "7.0.3",
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz",
@@ -15462,6 +15517,15 @@
"resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz",
"integrity": "sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw=="
},
"node_modules/diff": {
"version": "4.0.2",
"resolved": "https://registry.npmjs.org/diff/-/diff-4.0.2.tgz",
"integrity": "sha512-58lmxKSA4BNyLz+HHMUzlOEpg09FV+ev6ZMe3vJihgdxzgcwZ8VoEEPmALCZG9LmqfVoNMMKpttIYTVG6uDY7A==",
"devOptional": true,
"engines": {
"node": ">=0.3.1"
}
},
"node_modules/diffie-hellman": {
"version": "5.0.3",
"resolved": "https://registry.npmjs.org/diffie-hellman/-/diffie-hellman-5.0.3.tgz",
@@ -17688,6 +17752,47 @@
"node": ">= 6"
}
},
"node_modules/formik": {
"version": "2.2.9",
"resolved": "https://registry.npmjs.org/formik/-/formik-2.2.9.tgz",
"integrity": "sha512-LQLcISMmf1r5at4/gyJigGn0gOwFbeEAlji+N9InZF6LIMXnFNkO42sCI8Jt84YZggpD4cPWObAZaxpEFtSzNA==",
"funding": [
{
"type": "individual",
"url": "https://opencollective.com/formik"
}
],
"dependencies": {
"deepmerge": "^2.1.1",
"hoist-non-react-statics": "^3.3.0",
"lodash": "^4.17.21",
"lodash-es": "^4.17.21",
"react-fast-compare": "^2.0.1",
"tiny-warning": "^1.0.2",
"tslib": "^1.10.0"
},
"peerDependencies": {
"react": ">=16.8.0"
}
},
"node_modules/formik/node_modules/deepmerge": {
"version": "2.2.1",
"resolved": "https://registry.npmjs.org/deepmerge/-/deepmerge-2.2.1.tgz",
"integrity": "sha512-R9hc1Xa/NOBi9WRVUWg19rl1UB7Tt4kuPd+thNJgFZoxXsTz7ncaPaeIm+40oSGuP33DfMb4sZt1QIGiJzC4EA==",
"engines": {
"node": ">=0.10.0"
}
},
"node_modules/formik/node_modules/react-fast-compare": {
"version": "2.0.4",
"resolved": "https://registry.npmjs.org/react-fast-compare/-/react-fast-compare-2.0.4.tgz",
"integrity": "sha512-suNP+J1VU1MWFKcyt7RtjiSWUjvidmQSlqu+eHslq+342xCbGTYmC0mEhPCOHxlW0CywylOC1u2DFAT+bv4dBw=="
},
"node_modules/formik/node_modules/tslib": {
"version": "1.14.1",
"resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz",
"integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg=="
},
"node_modules/forwarded": {
"version": "0.2.0",
"resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz",
@@ -20409,8 +20514,12 @@
"node_modules/lodash": {
"version": "4.17.21",
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz",
"integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==",
"dev": true
"integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg=="
},
"node_modules/lodash-es": {
"version": "4.17.21",
"resolved": "https://registry.npmjs.org/lodash-es/-/lodash-es-4.17.21.tgz",
"integrity": "sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw=="
},
"node_modules/lodash.debounce": {
"version": "4.0.8",
@@ -20690,6 +20799,12 @@
"semver": "bin/semver"
}
},
"node_modules/make-error": {
"version": "1.3.6",
"resolved": "https://registry.npmjs.org/make-error/-/make-error-1.3.6.tgz",
"integrity": "sha512-s8UhlNe7vPKomQhC1qFelMokr/Sc3AgNbso3n74mVPA5LTZwkB9NlXf4XPamLxJE8h0gh73rM94xvwRT2CVInw==",
"devOptional": true
},
"node_modules/makeerror": {
"version": "1.0.12",
"resolved": "https://registry.npmjs.org/makeerror/-/makeerror-1.0.12.tgz",
@@ -26310,6 +26425,16 @@
"resolved": "https://registry.npmjs.org/react-fast-compare/-/react-fast-compare-3.2.0.tgz",
"integrity": "sha512-rtGImPZ0YyLrscKI9xTpV8psd6I8VAtjKCzQDlzyDvqJA8XOW78TXYQwNRNd8g8JZnDu8q9Fu/1v4HPAVwVdHA=="
},
"node_modules/react-feature-flags": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/react-feature-flags/-/react-feature-flags-1.0.0.tgz",
"integrity": "sha512-KBFUkXjF7ifGWEQr2Ida4LdAtKGDOwFdTRlXipWxGP9a43vUBxP6IscpYQofGjlzlBcgmFKuzubcVheB6NliEg==",
"peerDependencies": {
"prop-types": "^15.5.4",
"react": ">= 16.3.0",
"react-dom": ">= 16.3.0"
}
},
"node_modules/react-focus-lock": {
"version": "2.9.2",
"resolved": "https://registry.npmjs.org/react-focus-lock/-/react-focus-lock-2.9.2.tgz",
@@ -29081,6 +29206,11 @@
"resolved": "https://registry.npmjs.org/tiny-invariant/-/tiny-invariant-1.3.1.tgz",
"integrity": "sha512-AD5ih2NlSssTCwsMznbvwMZpJ1cbhkGd2uueNxzv2jDlEeZdU04JQfRnggJQ8DrcVBGjAsCKwFBbDlVNtEMlzw=="
},
"node_modules/tiny-warning": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/tiny-warning/-/tiny-warning-1.0.3.tgz",
"integrity": "sha512-lBN9zLN/oAf68o3zNXYrdCt1kP8WsiGW8Oo2ka41b2IM5JL/S1CTyX1rW0mb/zSuJun0ZUrDxx4sqvYS2FWzPA=="
},
"node_modules/tmp": {
"version": "0.2.1",
"resolved": "https://registry.npmjs.org/tmp/-/tmp-0.2.1.tgz",
@@ -29247,6 +29377,70 @@
"node": ">=6.10"
}
},
"node_modules/ts-node": {
"version": "10.9.1",
"resolved": "https://registry.npmjs.org/ts-node/-/ts-node-10.9.1.tgz",
"integrity": "sha512-NtVysVPkxxrwFGUUxGYhfux8k78pQB3JqYBXlLRZgdGUqTO5wU/UyHop5p70iEbGhB7q5KmiZiU0Y3KlJrScEw==",
"devOptional": true,
"dependencies": {
"@cspotcode/source-map-support": "^0.8.0",
"@tsconfig/node10": "^1.0.7",
"@tsconfig/node12": "^1.0.7",
"@tsconfig/node14": "^1.0.0",
"@tsconfig/node16": "^1.0.2",
"acorn": "^8.4.1",
"acorn-walk": "^8.1.1",
"arg": "^4.1.0",
"create-require": "^1.1.0",
"diff": "^4.0.1",
"make-error": "^1.1.1",
"v8-compile-cache-lib": "^3.0.1",
"yn": "3.1.1"
},
"bin": {
"ts-node": "dist/bin.js",
"ts-node-cwd": "dist/bin-cwd.js",
"ts-node-esm": "dist/bin-esm.js",
"ts-node-script": "dist/bin-script.js",
"ts-node-transpile-only": "dist/bin-transpile.js",
"ts-script": "dist/bin-script-deprecated.js"
},
"peerDependencies": {
"@swc/core": ">=1.2.50",
"@swc/wasm": ">=1.2.50",
"@types/node": "*",
"typescript": ">=2.7"
},
"peerDependenciesMeta": {
"@swc/core": {
"optional": true
},
"@swc/wasm": {
"optional": true
}
}
},
"node_modules/ts-node/node_modules/acorn": {
"version": "8.8.1",
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.8.1.tgz",
"integrity": "sha512-7zFpHzhnqYKrkYdUjF1HI1bzd0VygEGX8lFk4k5zVMqHEoES+P+7TKI+EvLO9WVMJ8eekdO0aDEK044xTXwPPA==",
"devOptional": true,
"bin": {
"acorn": "bin/acorn"
},
"engines": {
"node": ">=0.4.0"
}
},
"node_modules/ts-node/node_modules/acorn-walk": {
"version": "8.2.0",
"resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.2.0.tgz",
"integrity": "sha512-k+iyHEuPgSw6SbuDpGQM+06HQUa04DZ3o+F6CSzXMvvI5KMvnaEqXe+YVe555R9nn6GPt404fos4wcgpw12SDA==",
"devOptional": true,
"engines": {
"node": ">=0.4.0"
}
},
"node_modules/ts-pnp": {
"version": "1.2.0",
"resolved": "https://registry.npmjs.org/ts-pnp/-/ts-pnp-1.2.0.tgz",
@@ -29951,6 +30145,12 @@
"integrity": "sha512-dsNgbLaTrd6l3MMxTtouOCFw4CBFc/3a+GgYA2YyrJvyQ1u6q4pcu3ktLoUZ/VN/Aw9WsauazbgsgdfVWgAKQg==",
"dev": true
},
"node_modules/v8-compile-cache-lib": {
"version": "3.0.1",
"resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",
"integrity": "sha512-wa7YjyUGfNZngI/vtK0UHAN+lgDCxBPCylVXGp0zu59Fz5aiGtNXaq3DhIov063MorB+VfufLh3JlF2KdTK3xg==",
"devOptional": true
},
"node_modules/v8-to-istanbul": {
"version": "9.0.1",
"resolved": "https://registry.npmjs.org/v8-to-istanbul/-/v8-to-istanbul-9.0.1.tgz",
@@ -30850,6 +31050,15 @@
"fd-slicer": "~1.1.0"
}
},
"node_modules/yn": {
"version": "3.1.1",
"resolved": "https://registry.npmjs.org/yn/-/yn-3.1.1.tgz",
"integrity": "sha512-Ux4ygGWsu2c7isFWe8Yu1YluJmqVhxqK2cLXNQA5AcC3QfbGNpM7fu0Y8b/z16pXLnFxZYvWhd3fhBY9DLmC6Q==",
"devOptional": true,
"engines": {
"node": ">=6"
}
},
"node_modules/yocto-queue": {
"version": "0.1.0",
"resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz",
@@ -33044,6 +33253,27 @@
"dev": true,
"optional": true
},
"@cspotcode/source-map-support": {
"version": "0.8.1",
"resolved": "https://registry.npmjs.org/@cspotcode/source-map-support/-/source-map-support-0.8.1.tgz",
"integrity": "sha512-IchNf6dN4tHoMFIn/7OE8LWZ19Y6q/67Bmf6vnGREv8RSbBVb9LPJxEcnwrcwX6ixSvaiGoomAUvu4YSxXrVgw==",
"devOptional": true,
"requires": {
"@jridgewell/trace-mapping": "0.3.9"
},
"dependencies": {
"@jridgewell/trace-mapping": {
"version": "0.3.9",
"resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.9.tgz",
"integrity": "sha512-3Belt6tdc8bPgAtbcmdtNJlirVoTmEb5e2gC94PnkwEW9jI6CAHUeoG85tjWP5WquqfavoMtMwiG4P926ZKKuQ==",
"devOptional": true,
"requires": {
"@jridgewell/resolve-uri": "^3.0.3",
"@jridgewell/sourcemap-codec": "^1.4.10"
}
}
}
},
"@cypress/request": {
"version": "2.88.10",
"resolved": "https://registry.npmjs.org/@cypress/request/-/request-2.88.10.tgz",
@@ -38884,6 +39114,30 @@
"@babel/runtime": "^7.12.5"
}
},
"@tsconfig/node10": {
"version": "1.0.9",
"resolved": "https://registry.npmjs.org/@tsconfig/node10/-/node10-1.0.9.tgz",
"integrity": "sha512-jNsYVVxU8v5g43Erja32laIDHXeoNvFEpX33OK4d6hljo3jDhCBDhx5dhCCTMWUojscpAagGiRkBKxpdl9fxqA==",
"devOptional": true
},
"@tsconfig/node12": {
"version": "1.0.11",
"resolved": "https://registry.npmjs.org/@tsconfig/node12/-/node12-1.0.11.tgz",
"integrity": "sha512-cqefuRsh12pWyGsIoBKJA9luFu3mRxCA+ORZvA4ktLSzIuCUtWVxGIuXigEwO5/ywWFMZ2QEGKWvkZG1zDMTag==",
"devOptional": true
},
"@tsconfig/node14": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/@tsconfig/node14/-/node14-1.0.3.tgz",
"integrity": "sha512-ysT8mhdixWK6Hw3i1V2AeRqZ5WfXg1G43mqoYlM2nc6388Fq5jcXyr5mRsqViLx/GJYdoL0bfXD8nmF+Zn/Iow==",
"devOptional": true
},
"@tsconfig/node16": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/@tsconfig/node16/-/node16-1.0.3.tgz",
"integrity": "sha512-yOlFc+7UtL/89t2ZhjPvvB/DeAr3r+Dq58IgzsFkOAvVC6NMJXmCGjbptdXdR9qsX7pKcTL+s87FtYREi2dEEQ==",
"devOptional": true
},
"@types/aria-query": {
"version": "5.0.1",
"resolved": "https://registry.npmjs.org/@types/aria-query/-/aria-query-5.0.1.tgz",
@@ -39023,7 +39277,7 @@
"version": "18.11.17",
"resolved": "https://registry.npmjs.org/@types/node/-/node-18.11.17.tgz",
"integrity": "sha512-HJSUJmni4BeDHhfzn6nF0sVmd1SMezP7/4F0Lq+aXzmp2xm9O7WXrUtHW/CHlYVtZUbByEvWidHqRtcJXGF2Ng==",
"dev": true
"devOptional": true
},
"@types/node-fetch": {
"version": "2.6.2",
@@ -39875,7 +40129,7 @@
"version": "4.1.3",
"resolved": "https://registry.npmjs.org/arg/-/arg-4.1.3.tgz",
"integrity": "sha512-58S9QDqG0Xx27YwPSt9fJxivjYl432YCwfDMfZ+71RAqUrZef7LrKQZ3LHLOwCS4FLNBplP533Zx895SeOCHvA==",
"dev": true
"devOptional": true
},
"argparse": {
"version": "1.0.10",
@@ -41964,6 +42218,12 @@
"sha.js": "^2.4.8"
}
},
"create-require": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/create-require/-/create-require-1.1.1.tgz",
"integrity": "sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==",
"devOptional": true
},
"cross-spawn": {
"version": "7.0.3",
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz",
@@ -42547,6 +42807,12 @@
"resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz",
"integrity": "sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw=="
},
"diff": {
"version": "4.0.2",
"resolved": "https://registry.npmjs.org/diff/-/diff-4.0.2.tgz",
"integrity": "sha512-58lmxKSA4BNyLz+HHMUzlOEpg09FV+ev6ZMe3vJihgdxzgcwZ8VoEEPmALCZG9LmqfVoNMMKpttIYTVG6uDY7A==",
"devOptional": true
},
"diffie-hellman": {
"version": "5.0.3",
"resolved": "https://registry.npmjs.org/diffie-hellman/-/diffie-hellman-5.0.3.tgz",
@@ -44293,6 +44559,37 @@
"mime-types": "^2.1.12"
}
},
"formik": {
"version": "2.2.9",
"resolved": "https://registry.npmjs.org/formik/-/formik-2.2.9.tgz",
"integrity": "sha512-LQLcISMmf1r5at4/gyJigGn0gOwFbeEAlji+N9InZF6LIMXnFNkO42sCI8Jt84YZggpD4cPWObAZaxpEFtSzNA==",
"requires": {
"deepmerge": "^2.1.1",
"hoist-non-react-statics": "^3.3.0",
"lodash": "^4.17.21",
"lodash-es": "^4.17.21",
"react-fast-compare": "^2.0.1",
"tiny-warning": "^1.0.2",
"tslib": "^1.10.0"
},
"dependencies": {
"deepmerge": {
"version": "2.2.1",
"resolved": "https://registry.npmjs.org/deepmerge/-/deepmerge-2.2.1.tgz",
"integrity": "sha512-R9hc1Xa/NOBi9WRVUWg19rl1UB7Tt4kuPd+thNJgFZoxXsTz7ncaPaeIm+40oSGuP33DfMb4sZt1QIGiJzC4EA=="
},
"react-fast-compare": {
"version": "2.0.4",
"resolved": "https://registry.npmjs.org/react-fast-compare/-/react-fast-compare-2.0.4.tgz",
"integrity": "sha512-suNP+J1VU1MWFKcyt7RtjiSWUjvidmQSlqu+eHslq+342xCbGTYmC0mEhPCOHxlW0CywylOC1u2DFAT+bv4dBw=="
},
"tslib": {
"version": "1.14.1",
"resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz",
"integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg=="
}
}
},
"forwarded": {
"version": "0.2.0",
"resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz",
@@ -46307,8 +46604,12 @@
"lodash": {
"version": "4.17.21",
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz",
"integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==",
"dev": true
"integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg=="
},
"lodash-es": {
"version": "4.17.21",
"resolved": "https://registry.npmjs.org/lodash-es/-/lodash-es-4.17.21.tgz",
"integrity": "sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw=="
},
"lodash.debounce": {
"version": "4.0.8",
@@ -46525,6 +46826,12 @@
}
}
},
"make-error": {
"version": "1.3.6",
"resolved": "https://registry.npmjs.org/make-error/-/make-error-1.3.6.tgz",
"integrity": "sha512-s8UhlNe7vPKomQhC1qFelMokr/Sc3AgNbso3n74mVPA5LTZwkB9NlXf4XPamLxJE8h0gh73rM94xvwRT2CVInw==",
"devOptional": true
},
"makeerror": {
"version": "1.0.12",
"resolved": "https://registry.npmjs.org/makeerror/-/makeerror-1.0.12.tgz",
@@ -50542,6 +50849,12 @@
"resolved": "https://registry.npmjs.org/react-fast-compare/-/react-fast-compare-3.2.0.tgz",
"integrity": "sha512-rtGImPZ0YyLrscKI9xTpV8psd6I8VAtjKCzQDlzyDvqJA8XOW78TXYQwNRNd8g8JZnDu8q9Fu/1v4HPAVwVdHA=="
},
"react-feature-flags": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/react-feature-flags/-/react-feature-flags-1.0.0.tgz",
"integrity": "sha512-KBFUkXjF7ifGWEQr2Ida4LdAtKGDOwFdTRlXipWxGP9a43vUBxP6IscpYQofGjlzlBcgmFKuzubcVheB6NliEg==",
"requires": {}
},
"react-focus-lock": {
"version": "2.9.2",
"resolved": "https://registry.npmjs.org/react-focus-lock/-/react-focus-lock-2.9.2.tgz",
@@ -52722,6 +53035,11 @@
"resolved": "https://registry.npmjs.org/tiny-invariant/-/tiny-invariant-1.3.1.tgz",
"integrity": "sha512-AD5ih2NlSssTCwsMznbvwMZpJ1cbhkGd2uueNxzv2jDlEeZdU04JQfRnggJQ8DrcVBGjAsCKwFBbDlVNtEMlzw=="
},
"tiny-warning": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/tiny-warning/-/tiny-warning-1.0.3.tgz",
"integrity": "sha512-lBN9zLN/oAf68o3zNXYrdCt1kP8WsiGW8Oo2ka41b2IM5JL/S1CTyX1rW0mb/zSuJun0ZUrDxx4sqvYS2FWzPA=="
},
"tmp": {
"version": "0.2.1",
"resolved": "https://registry.npmjs.org/tmp/-/tmp-0.2.1.tgz",
@@ -52852,6 +53170,41 @@
"integrity": "sha512-q5W7tVM71e2xjHZTlgfTDoPF/SmqKG5hddq9SzR49CH2hayqRKJtQ4mtRlSxKaJlR/+9rEM+mnBHf7I2/BQcpQ==",
"dev": true
},
"ts-node": {
"version": "10.9.1",
"resolved": "https://registry.npmjs.org/ts-node/-/ts-node-10.9.1.tgz",
"integrity": "sha512-NtVysVPkxxrwFGUUxGYhfux8k78pQB3JqYBXlLRZgdGUqTO5wU/UyHop5p70iEbGhB7q5KmiZiU0Y3KlJrScEw==",
"devOptional": true,
"requires": {
"@cspotcode/source-map-support": "^0.8.0",
"@tsconfig/node10": "^1.0.7",
"@tsconfig/node12": "^1.0.7",
"@tsconfig/node14": "^1.0.0",
"@tsconfig/node16": "^1.0.2",
"acorn": "^8.4.1",
"acorn-walk": "^8.1.1",
"arg": "^4.1.0",
"create-require": "^1.1.0",
"diff": "^4.0.1",
"make-error": "^1.1.1",
"v8-compile-cache-lib": "^3.0.1",
"yn": "3.1.1"
},
"dependencies": {
"acorn": {
"version": "8.8.1",
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.8.1.tgz",
"integrity": "sha512-7zFpHzhnqYKrkYdUjF1HI1bzd0VygEGX8lFk4k5zVMqHEoES+P+7TKI+EvLO9WVMJ8eekdO0aDEK044xTXwPPA==",
"devOptional": true
},
"acorn-walk": {
"version": "8.2.0",
"resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.2.0.tgz",
"integrity": "sha512-k+iyHEuPgSw6SbuDpGQM+06HQUa04DZ3o+F6CSzXMvvI5KMvnaEqXe+YVe555R9nn6GPt404fos4wcgpw12SDA==",
"devOptional": true
}
}
},
"ts-pnp": {
"version": "1.2.0",
"resolved": "https://registry.npmjs.org/ts-pnp/-/ts-pnp-1.2.0.tgz",
@@ -53362,6 +53715,12 @@
"integrity": "sha512-dsNgbLaTrd6l3MMxTtouOCFw4CBFc/3a+GgYA2YyrJvyQ1u6q4pcu3ktLoUZ/VN/Aw9WsauazbgsgdfVWgAKQg==",
"dev": true
},
"v8-compile-cache-lib": {
"version": "3.0.1",
"resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",
"integrity": "sha512-wa7YjyUGfNZngI/vtK0UHAN+lgDCxBPCylVXGp0zu59Fz5aiGtNXaq3DhIov063MorB+VfufLh3JlF2KdTK3xg==",
"devOptional": true
},
"v8-to-istanbul": {
"version": "9.0.1",
"resolved": "https://registry.npmjs.org/v8-to-istanbul/-/v8-to-istanbul-9.0.1.tgz",
@@ -54083,6 +54442,12 @@
"fd-slicer": "~1.1.0"
}
},
"yn": {
"version": "3.1.1",
"resolved": "https://registry.npmjs.org/yn/-/yn-3.1.1.tgz",
"integrity": "sha512-Ux4ygGWsu2c7isFWe8Yu1YluJmqVhxqK2cLXNQA5AcC3QfbGNpM7fu0Y8b/z16pXLnFxZYvWhd3fhBY9DLmC6Q==",
"devOptional": true
},
"yocto-queue": {
"version": "0.1.0",
"resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz",
+8 -2
View File
@@ -18,6 +18,9 @@
"fix:format": "prettier --write ./src",
"fix": "npm run fix:format && npm run fix:lint"
},
"prisma": {
"seed": "ts-node --compiler-options {\"module\":\"CommonJS\"} prisma/seed.ts"
},
"dependencies": {
"@chakra-ui/react": "^2.4.4",
"@dnd-kit/core": "^6.0.6",
@@ -40,6 +43,7 @@
"eslint-config-next": "13.0.6",
"eslint-plugin-simple-import-sort": "^8.0.0",
"focus-visible": "^5.2.0",
"formik": "^2.2.9",
"framer-motion": "^6.5.1",
"install": "^0.13.0",
"next": "13.0.6",
@@ -49,6 +53,7 @@
"postcss-focus-visible": "^7.1.0",
"react": "18.2.0",
"react-dom": "18.2.0",
"react-feature-flags": "^1.0.0",
"react-icons": "^4.7.1",
"swr": "^2.0.0",
"tailwindcss": "^3.2.4",
@@ -67,7 +72,7 @@
"@storybook/manager-webpack5": "^6.5.15",
"@storybook/react": "^6.5.15",
"@storybook/testing-library": "^0.0.13",
"@types/node": "18.11.17",
"@types/node": "^18.11.17",
"@types/react": "18.0.26",
"@typescript-eslint/eslint-plugin": "^5.47.1",
"babel-loader": "^8.3.0",
@@ -77,6 +82,7 @@
"eslint-plugin-unused-imports": "^2.0.0",
"prettier": "2.8.1",
"prisma": "^4.7.1",
"typescript": "4.9.4"
"ts-node": "^10.9.1",
"typescript": "^4.9.4"
}
}
+57
View File
@@ -0,0 +1,57 @@
/**
* A seed function to inject test data into the web database.
*
* Use by running
* npx prisma db seed
*/
import { PrismaClient } from "@prisma/client";
const prisma = new PrismaClient();
async function main() {
const users = [
{ email: "general.user.a@example.com", name: "A", role: "general" },
{ email: "general.user.b@example.com", name: "B", role: "general" },
{ email: "general.user.c@example.com", name: "C", role: "general" },
{ email: "general.user.d@example.com", name: "D", role: "general" },
{ email: "general.user.e@example.com", name: "E", role: "general" },
{ email: "general.user.f@example.com", name: "F", role: "general" },
{ email: "general.user.g@example.com", name: "G", role: "general" },
{ email: "general.user.h@example.com", name: "H", role: "general" },
{ email: "general.user.i@example.com", name: "I", role: "general" },
{ email: "general.user.j@example.com", name: "J", role: "general" },
{ email: "general.user.k@example.com", name: "K", role: "general" },
{ email: "general.user.l@example.com", name: "L", role: "general" },
{ email: "general.user.m@example.com", name: "M", role: "general" },
{ email: "general.user.n@example.com", name: "N", role: "general" },
{ email: "general.user.o@example.com", name: "O", role: "general" },
{ email: "general.user.p@example.com", name: "P", role: "general" },
{ email: "general.user.q@example.com", name: "Q", role: "general" },
{ email: "general.user.r@example.com", name: "R", role: "general" },
{ email: "malicious.user.1@example.com", name: "M1", role: "general" },
{ email: "malicious.user.2@example.com", name: "M2", role: "general" },
];
await Promise.all(
users.map(async ({ email, name, role }) => {
await prisma.user.upsert({
where: { email },
update: { name, role },
create: {
email,
name,
role,
},
});
})
);
}
main()
.then(async () => {
await prisma.$disconnect();
})
.catch(async (e) => {
console.error(e);
await prisma.$disconnect();
process.exit(1);
});
+59 -5
View File
@@ -1,9 +1,63 @@
import { Button, ButtonProps } from "@chakra-ui/react";
import {
Button,
ButtonProps,
Menu,
MenuButton,
MenuItem,
MenuList,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Textarea,
useDisclosure,
} from "@chakra-ui/react";
import { useState } from "react";
import { FaChevronDown } from "react-icons/fa";
interface SkipButtonProps extends ButtonProps {
onSkip: (reason: string) => void;
}
export const SkipButton = ({ onSkip, ...props }: SkipButtonProps) => {
const { isOpen, onOpen: showModal, onClose: closeModal } = useDisclosure();
const [value, setValue] = useState("");
const onSubmit = () => {
onSkip(value);
setValue("");
closeModal();
};
export const SkipButton = ({ children, ...props }: ButtonProps) => {
return (
<Button size="lg" variant="outline" {...props}>
{children}
</Button>
<>
<Button size="lg" variant="outline" onClick={showModal} {...props}>
Skip
</Button>
<Modal isOpen={isOpen} onClose={closeModal}>
<ModalOverlay />
<ModalContent>
<ModalHeader>Skip</ModalHeader>
<ModalCloseButton />
<ModalBody>
<Textarea
value={value}
onChange={(e) => setValue(e.target.value)}
resize="none"
placeholder="Any feedback on this task?"
/>
</ModalBody>
<ModalFooter>
<Button colorScheme="blue" mr={3} onClick={onSubmit}>
Send
</Button>
</ModalFooter>
</ModalContent>
</Modal>
</>
);
};
@@ -3,7 +3,7 @@ import Link from "next/link";
import { TaskCategory, TaskTypes } from "../Tasks/TaskTypes";
const displayTaskCategories = [TaskCategory.Create, TaskCategory.Evaluate];
const displayTaskCategories = [TaskCategory.Create, TaskCategory.Evaluate, TaskCategory.Label];
export const TaskOption = () => {
const backgroundColor = useColorModeValue("white", "gray.700");
@@ -12,9 +12,9 @@ export const TaskOption = () => {
<Box className="flex flex-col gap-14" fontFamily="inter">
{displayTaskCategories.map((category, categoryIndex) => (
<div key={categoryIndex}>
<Text className="text-2xl font-bold pb-4">{TaskCategory[category]}</Text>
<Text className="text-2xl font-bold pb-4">{category}</Text>
<SimpleGrid columns={[1, 2, 2, 3, 4]} gap={4}>
{TaskTypes.filter((task) => task.category == category).map((item, itemIndex) => (
{TaskTypes.filter((task) => task.category === category).map((item, itemIndex) => (
<Link key={itemIndex} href={item.pathname}>
<GridItem
bg={backgroundColor}
+39 -33
View File
@@ -29,7 +29,7 @@ import useSWRMutation from "swr/mutation";
export const FlaggableElement = (props) => {
const [isEditing, setIsEditing] = useBoolean();
const { trigger } = useSWRMutation("/api/v1/text_labels", poster, {
const { trigger } = useSWRMutation("/api/set_label", poster, {
onSuccess: () => {
setIsEditing.off;
},
@@ -42,7 +42,12 @@ export const FlaggableElement = (props) => {
label_map.set(flag.attributeName, sliderValues[i]);
}
});
trigger({ post_id: props.post_id, label_map: Object.fromEntries(label_map), text: props.text });
trigger({
message_id: props.message_id,
post_id: props.post_id,
label_map: Object.fromEntries(label_map),
text: props.text,
});
};
const [checkboxValues, setCheckboxValues] = useState(new Array(TEXT_LABEL_FLAGS.length).fill(false));
const [sliderValues, setSliderValues] = useState(new Array(TEXT_LABEL_FLAGS.length).fill(1));
@@ -118,7 +123,8 @@ export const FlaggableElement = (props) => {
</Popover>
);
};
function FlagCheckbox(props: {
export function FlagCheckbox(props: {
option: textFlagLabels;
idx: number;
checkboxValues: boolean[];
@@ -183,40 +189,40 @@ interface textFlagLabels {
const TEXT_LABEL_FLAGS: textFlagLabels[] = [
// For the time being this list is configured on the FE.
// In the future it may be provided by the API.
// {
// attributeName: "fails_task",
// labelText: "Fails to follow the correct instruction / task",
// additionalExplanation: "__TODO__",
// },
// {
// attributeName: "not_customer_assistant_appropriate",
// labelText: "Inappropriate for customer assistant",
// additionalExplanation: "__TODO__",
// },
{
attributeName: "fails_task",
labelText: "Fails to follow the correct instruction / task",
additionalExplanation: "__TODO__",
},
{
attributeName: "not_customer_assistant_appropriate",
labelText: "Inappropriate for customer assistant",
additionalExplanation: "__TODO__",
},
{
attributeName: "contains_sexual_content",
attributeName: "sexual_content",
labelText: "Contains sexual content",
},
{
attributeName: "contains_violent_content",
attributeName: "violence",
labelText: "Contains violent content",
},
{
attributeName: "encourages_violence",
labelText: "Encourages or fails to discourage violence/abuse/terrorism/self-harm",
},
{
attributeName: "denigrates_a_protected_class",
labelText: "Denigrates a protected class",
},
{
attributeName: "gives_harmful_advice",
labelText: "Fails to follow the correct instruction / task",
additionalExplanation:
"The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the question about encouraging violence/abuse/terrorism/self-harm.",
},
{
attributeName: "expresses_moral_judgement",
labelText: "Expresses moral judgement",
},
// {
// attributeName: "encourages_violence",
// labelText: "Encourages or fails to discourage violence/abuse/terrorism/self-harm",
// },
// {
// attributeName: "denigrates_a_protected_class",
// labelText: "Denigrates a protected class",
// },
// {
// attributeName: "gives_harmful_advice",
// labelText: "Fails to follow the correct instruction / task",
// additionalExplanation:
// "The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the question about encouraging violence/abuse/terrorism/self-harm.",
// },
// {
// attributeName: "expresses_moral_judgement",
// labelText: "Expresses moral judgement",
// },
];
+4
View File
@@ -2,6 +2,7 @@ import { Box, Button, Text, useColorMode } from "@chakra-ui/react";
import Image from "next/image";
import Link from "next/link";
import { useSession } from "next-auth/react";
import { Flags } from "react-feature-flags";
import { FaUser } from "react-icons/fa";
import { UserMenu } from "./UserMenu";
@@ -42,6 +43,9 @@ export function Header(props) {
</Link>
</div>
<div className="flex items-center gap-4">
<Flags authorizedFlags={["flagTest"]}>
<div>FlagTest</div>
</Flags>
<AccountButton />
<UserMenu />
</div>
+19 -1
View File
@@ -1,7 +1,7 @@
// https://nextjs.org/docs/basic-features/layouts
import type { NextPage } from "next";
import { FiLayout, FiMessageSquare } from "react-icons/fi";
import { FiLayout, FiMessageSquare, FiUsers } from "react-icons/fi";
import { Header } from "src/components/Header";
import { Footer } from "./Footer";
@@ -51,4 +51,22 @@ export const getDashboardLayout = (page: React.ReactElement) => (
</div>
);
export const getAdminLayout = (page: React.ReactElement) => (
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
<Header transparent={true} />
<SideMenuLayout
menuButtonOptions={[
{
label: "Users",
pathname: "/admin",
desc: "Users Dashboard",
icon: FiUsers,
},
]}
>
{page}
</SideMenuLayout>
</div>
);
export const noLayout = (page: React.ReactElement) => page;
@@ -1,7 +1,7 @@
import { Progress } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
export const LoadingScreen = ({ text }) => {
export const LoadingScreen = ({ text = "Loading..." } = {}) => {
const { colorMode } = useColorMode();
const mainClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
+20 -18
View File
@@ -1,36 +1,38 @@
import { Grid } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { useMemo } from "react";
import { FlaggableElement } from "./FlaggableElement";
export interface Message {
text: string;
is_assistant: boolean;
message_id: string;
}
const getBgColor = (isAssistant: boolean, colorMode: "light" | "dark") => {
if (colorMode === "light") {
return isAssistant ? "bg-slate-800" : "bg-sky-900";
} else {
return isAssistant ? "bg-black" : "bg-sky-900";
}
};
export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => {
const { colorMode } = useColorMode();
const items = messages.map(({ text, is_assistant }: Message, i: number) => {
const items = messages.map((messageProps: Message, i: number) => {
const { message_id, text } = messageProps;
return (
<FlaggableElement text={text} post_id={post_id} key={i + text}>
<div
key={i + text}
className={`${getBgColor(is_assistant, colorMode)} p-4 rounded-md text-white whitespace-pre-wrap`}
>
{text}
</div>
<FlaggableElement text={text} post_id={post_id} message_id={message_id} key={i + text}>
<MessageView {...messageProps} />
</FlaggableElement>
);
});
// Maybe also show a legend of the colors?
return <Grid gap={2}>{items}</Grid>;
};
export const MessageView = ({ is_assistant, text, message_id }: Message) => {
const { colorMode } = useColorMode();
const bgColor = useMemo(() => {
if (colorMode === "light") {
return is_assistant ? "bg-slate-800" : "bg-sky-900";
} else {
return is_assistant ? "bg-black" : "bg-sky-900";
}
}, [colorMode, is_assistant]);
return <div className={`${bgColor} p-4 rounded-md text-white whitespace-pre-wrap`}>{text}</div>;
};
+20 -12
View File
@@ -1,5 +1,6 @@
import { useColorMode } from "@chakra-ui/react";
import { Flex } from "@chakra-ui/react";
import clsx from "clsx";
import { SkipButton } from "src/components/Buttons/Skip";
import { SubmitButton } from "src/components/Buttons/Submit";
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
@@ -10,31 +11,38 @@ export interface TaskControlsProps {
tasks: any[];
className?: string;
onSubmitResponse: (task: { id: string }) => void;
onSkip: () => void;
onSkipTask: (task: { id: string }, reason: string) => void;
onNextTask: () => void;
}
export const TaskControls = (props: TaskControlsProps) => {
const extraClases = props.className || "";
const { colorMode } = useColorMode();
const baseClasses = "flex flex-row justify-items-stretch mb-8 p-4 rounded-lg max-w-7xl mx-auto";
const taskControlClases =
colorMode === "light"
? `${baseClasses} bg-white text-gray-800 shadow-lg ${extraClases}`
: `${baseClasses} bg-slate-800 text-slate-400 shadow-xl ring-1 ring-white/10 ring-inset ${extraClases}`;
const isLightMode = colorMode === "light";
const endTask = props.tasks[props.tasks.length - 1];
return (
<section className={taskControlClases}>
<section
className={clsx(
"flex-row justify-items-stretch mb-8 p-4 rounded-lg max-w-7xl mx-auto space-y-4 sm:space-y-0 sm:flex",
props.className,
{
"bg-white text-gray-800 shadow-lg": isLightMode,
"bg-slate-800 text-slate-400 shadow-xl ring-1 ring-white/10 ring-inset": !isLightMode,
}
)}
>
<TaskInfo id={props.tasks[0].id} output="Submit your answer" />
<Flex justify="center" ml="auto" gap={2}>
<SkipButton>Skip</SkipButton>
<SkipButton
onSkip={(reason: string) => {
props.onSkipTask(props.tasks[0], reason);
}}
/>
{endTask.task.type !== "task_done" ? (
<SubmitButton colorScheme="blue" data-cy="submit" onClick={() => props.onSubmitResponse(props.tasks[0])}>
Submit
</SubmitButton>
) : (
<SubmitButton colorScheme="green" data-cy="next-task" onClick={props.onSkip}>
<SubmitButton colorScheme="green" data-cy="next-task" onClick={props.onNextTask}>
Next Task
</SubmitButton>
)}
@@ -28,7 +28,7 @@ export const TrackedTextarea = (props: TrackedTextboxProps) => {
return (
<Stack direction={"column"}>
<Textarea data-cy="reply" value={props.text} onChange={props.onTextChange} {...props.textareaProps} onCapture />
<Textarea data-cy="reply" value={props.text} onChange={props.onTextChange} {...props.textareaProps} />
<Progress size={"md"} rounded={"md"} value={wordCount} colorScheme={progressColor} max={props.thresholds.goal} />
</Stack>
);
+1 -1
View File
@@ -1,6 +1,6 @@
export const TaskInfo = ({ id, output }: { id: string; output: string }) => {
return (
<div className="grid grid-cols-[min-content_auto] gap-x-2 ">
<div className="grid grid-cols-[min-content_auto] gap-x-2">
<b>Prompt</b>
<span data-cy="task-id">{id}</span>
<b>Output</b>
@@ -1,39 +0,0 @@
import { Card, CardBody, Flex, Heading } from "@chakra-ui/react";
import Image from "next/image";
import Link from "next/link";
export type OptionProps = {
img: string;
alt: string;
title: string;
link: string;
};
export const TaskOption = (props: OptionProps) => {
const { alt, img, title, link } = props;
return (
<Link href={link}>
<Card
maxW="300"
minW="300"
minH="300"
maxH="300"
className="transition ease-in-out duration-500 sm:grayscale hover:grayscale-0"
>
<CardBody width="full" height="full">
<Flex direction="column" alignItems="center" justifyContent="center">
<Image src={img} alt={alt} width={200} height={200} />
<Heading
mt={-10}
className="bg-gradient-to-r from-indigo-600 via-sky-400 to-indigo-700 bg-clip-text tracking-tight text-transparent"
textAlign="center"
fontSize="3xl"
>
{title}
</Heading>
</Flex>
</CardBody>
</Card>
</Link>
);
};
@@ -1,23 +0,0 @@
import { Divider, Flex, Heading } from "@chakra-ui/react";
import React from "react";
export type TaskOptionsProps = {
title: string;
children: JSX.Element | JSX.Element[];
};
export const TaskOptions = (props: TaskOptionsProps) => {
const { title, children } = props;
return (
<Flex gap={10} wrap="wrap" justifyContent="center">
<Heading
className="bg-gradient-to-r from-indigo-600 via-sky-400 to-indigo-700 bg-clip-text tracking-tight text-transparent"
fontSize="5xl"
>
{title}
</Heading>
<Divider mt={-8} />
{children}
</Flex>
);
};
@@ -1,73 +0,0 @@
import { Flex } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import React from "react";
import { TaskOption } from "./TaskOption";
import { TaskOptions } from "./TaskOptions";
export const TaskSelection = () => {
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
return (
<Flex
gap={10}
wrap="wrap"
justifyContent="space-evenly"
width="full"
height="full"
alignItems={"center"}
className={mainBgClasses}
>
<TaskOptions key="create" title="Create">
{/* <TaskOption
alt="Summarize Stories"
img="/images/logos/logo.svg"
title="Summarize stories"
link="/create/summarize_story"
/> */}
<TaskOption
alt="Create Initial Prompt"
img="/images/logos/logo.svg"
title="Create Initial Prompt"
link="/create/initial_prompt"
/>
<TaskOption alt="Reply as User" img="/images/logos/logo.svg" title="Reply as User" link="/create/user_reply" />
<TaskOption
alt="Reply as Assistant"
img="/images/logos/logo.svg"
title="Reply as Assistant"
link="/create/assistant_reply"
/>
</TaskOptions>
<TaskOptions key="evaluate" title="Evaluate">
{/*
Commented out while the backend does not support them.
<TaskOption
alt="Rate Prompts"
img="/images/logos/logo.svg"
title="Rate Prompts"
link="/evaluate/rate_summary"
/> */}
<TaskOption
alt="Rank Initial Prompts"
img="/images/logos/logo.svg"
title="Rank Initial Prompts"
link="/evaluate/rank_initial_prompts"
/>
<TaskOption
alt="Rank User Replies"
img="/images/logos/logo.svg"
title="Rank User Replies"
link="/evaluate/rank_user_replies"
/>
<TaskOption
alt="Rank Assistant Replies"
img="/images/logos/logo.svg"
title="Rank Assistant Replies"
link="/evaluate/rank_assistant_replies"
/>
</TaskOptions>
</Flex>
);
};
@@ -1,3 +0,0 @@
export { TaskOption } from "./TaskOption";
export { TaskOptions } from "./TaskOptions";
export { TaskSelection } from "./TaskSelection";
+22 -7
View File
@@ -1,10 +1,22 @@
import { useState } from "react";
import { Messages } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TrackedTextarea } from "src/components/Survey/TrackedTextarea";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import { TaskType } from "./TaskTypes";
export const CreateTask = ({ tasks, taskType, trigger, mutate, mainBgClasses }) => {
export interface CreateTaskProps {
// we need a task type
// eslint-disable-next-line @typescript-eslint/no-explicit-any
tasks: any[];
taskType: TaskType;
trigger: (update: { id: string; update_type: string; content: { text: string } }) => void;
onSkipTask: (task: { id: string }, reason: string) => void;
onNextTask: () => void;
mainBgClasses: string;
}
export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, mainBgClasses }: CreateTaskProps) => {
const task = tasks[0].task;
const [inputText, setInputText] = useState("");
@@ -20,11 +32,6 @@ export const CreateTask = ({ tasks, taskType, trigger, mutate, mainBgClasses })
});
};
const fetchNextTask = () => {
setInputText("");
mutate();
};
const textChangeHandler = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
setInputText(event.target.value);
};
@@ -48,7 +55,15 @@ export const CreateTask = ({ tasks, taskType, trigger, mutate, mainBgClasses })
</>
</TwoColumnsWithCards>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
<TaskControls
tasks={tasks}
onSubmitResponse={submitResponse}
onSkipTask={(task, reason) => {
setInputText("");
onSkipTask(task, reason);
}}
onNextTask={onNextTask}
/>
</div>
);
};
+16 -6
View File
@@ -5,7 +5,17 @@ import { TaskControlsOverridable } from "src/components/Survey/TaskControlsOverr
import { MessageTable } from "../Messages/MessageTable";
export const EvaluateTask = ({ tasks, trigger, mutate, mainBgClasses }) => {
export interface EvaluateTaskProps {
// we need a task type
// eslint-disable-next-line @typescript-eslint/no-explicit-any
tasks: any[];
trigger: (update: { id: string; update_type: string; content: { ranking: number[] } }) => void;
onSkipTask: (task: { id: string }, reason: string) => void;
onNextTask: () => void;
mainBgClasses: string;
}
export const EvaluateTask = ({ tasks, trigger, onSkipTask, onNextTask, mainBgClasses }: EvaluateTaskProps) => {
const [ranking, setRanking] = useState<number[]>([]);
const submitResponse = (task) => {
trigger({
@@ -17,10 +27,6 @@ export const EvaluateTask = ({ tasks, trigger, mutate, mainBgClasses }) => {
});
};
const fetchNextTask = () => {
setRanking([]);
mutate();
};
let messages = null;
if (tasks[0].task.conversation) {
messages = tasks[0].task.conversation.messages;
@@ -45,7 +51,11 @@ export const EvaluateTask = ({ tasks, trigger, mutate, mainBgClasses }) => {
isValid={ranking.length == tasks[0].task[sortables].length}
prepareForSubmit={() => setRanking(tasks[0].task[sortables].map((_, idx) => idx))}
onSubmitResponse={submitResponse}
onSkip={fetchNextTask}
onSkipTask={(task, reason) => {
setRanking([]);
onSkipTask(task, reason);
}}
onNextTask={onNextTask}
/>
</div>
);
+100
View File
@@ -0,0 +1,100 @@
import { Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import { ReactNode, useEffect, useId, useMemo, useState } from "react";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import { colors } from "styles/Theme/colors";
export const LabelTask = ({
title,
desc,
messages,
inputs,
controls,
}: {
title: string;
desc: string;
messages: ReactNode;
inputs: ReactNode;
controls: ReactNode;
}) => {
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
const card = useMemo(
() => (
<>
<h5 className="text-lg font-semibold">{title}</h5>
<p className="text-lg py-1">{desc}</p>
{messages}
</>
),
[title, desc, messages]
);
return (
<div className={`p-12 ${mainBgClasses}`}>
<TwoColumnsWithCards>
{card}
{inputs}
</TwoColumnsWithCards>
{controls}
</div>
);
};
// TODO: consolidate with FlaggableElement
interface LabelSliderGroupProps {
labelIDs: Array<string>;
onChange: (sliderValues: number[]) => unknown;
}
export const LabelSliderGroup = ({ labelIDs, onChange }: LabelSliderGroupProps) => {
const [sliderValues, setSliderValues] = useState<number[]>(Array.from({ length: labelIDs.length }).map(() => 0));
useEffect(() => {
onChange(sliderValues);
}, [sliderValues, onChange]);
return (
<Grid templateColumns="auto 1fr" rowGap={1} columnGap={3}>
{labelIDs.map((labelId, idx) => (
<CheckboxSliderItem
key={idx}
labelId={labelId}
sliderValue={sliderValues[idx]}
sliderHandler={(sliderValue) => {
const newState = sliderValues.slice();
newState[idx] = sliderValue;
setSliderValues(newState);
}}
/>
))}
</Grid>
);
};
function CheckboxSliderItem(props: {
labelId: string;
sliderValue: number;
sliderHandler: (newVal: number) => unknown;
}) {
const id = useId();
const { colorMode } = useColorMode();
const labelTextClass = colorMode === "light" ? `text-${colors.light.text}` : `text-${colors.dark.text}`;
return (
<>
<label className="text-sm" htmlFor={id}>
{/* TODO: display real text instead of just the id */}
<span className={labelTextClass}>{props.labelId}</span>
</label>
<Slider defaultValue={0} onChangeEnd={(val) => props.sliderHandler(val / 100)}>
<SliderTrack>
<SliderFilledTrack />
<SliderThumb />
</SliderTrack>
</Slider>
</>
);
}
+26 -2
View File
@@ -1,10 +1,25 @@
import { CreateTask } from "./CreateTask";
import { EvaluateTask } from "./EvaluateTask";
import { TaskCategory, TaskTypes } from "./TaskTypes";
import useSWRMutation from "swr/mutation";
import poster from "src/lib/poster";
export const Task = ({ tasks, trigger, mutate, mainBgClasses }) => {
const task = tasks[0].task;
const { trigger: sendRejection } = useSWRMutation("/api/reject_task", poster, {
onSuccess: async () => {
mutate();
},
});
const rejectTask = (task: { id: string }, reason: string) => {
sendRejection({
id: task.id,
reason,
});
};
function taskTypeComponent(type) {
const taskType = TaskTypes.find((taskType) => taskType.type === type);
const category = taskType.category;
@@ -14,13 +29,22 @@ export const Task = ({ tasks, trigger, mutate, mainBgClasses }) => {
<CreateTask
tasks={tasks}
trigger={trigger}
mutate={mutate}
onSkipTask={rejectTask}
onNextTask={mutate}
taskType={taskType}
mainBgClasses={mainBgClasses}
/>
);
case TaskCategory.Evaluate:
return <EvaluateTask tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />;
return (
<EvaluateTask
tasks={tasks}
trigger={trigger}
onSkipTask={rejectTask}
onNextTask={mutate}
mainBgClasses={mainBgClasses}
/>
);
}
}
+38 -3
View File
@@ -1,9 +1,21 @@
export enum TaskCategory {
Create,
Evaluate,
Create = "Create",
Evaluate = "Evaluate",
Label = "Label",
}
export const TaskTypes = [
export interface TaskType {
label: string;
desc: string;
category: TaskCategory;
pathname: string;
type: string;
overview?: string;
instruction?: string;
}
export const TaskTypes: TaskType[] = [
// create
{
label: "Create Initial Prompts",
desc: "Write initial prompts to help Open Assistant to try replying to diverse messages.",
@@ -31,6 +43,7 @@ export const TaskTypes = [
overview: "Given the following conversation, provide an adequate reply",
instruction: "Provide the assistant`s reply",
},
// evaluate
{
label: "Rank User Replies",
category: TaskCategory.Evaluate,
@@ -52,4 +65,26 @@ export const TaskTypes = [
pathname: "/evaluate/rank_initial_prompts",
type: "rank_initial_prompts",
},
// label
{
label: "Label Initial Prompt",
desc: "Provide labels for a prompt.",
category: TaskCategory.Label,
pathname: "/label/label_initial_prompt",
type: "label_initial_prompt",
},
{
label: "Label Prompter Reply",
desc: "Provide labels for a prompt.",
category: TaskCategory.Label,
pathname: "/label/label_prompter_reply",
type: "label_prompter_reply",
},
{
label: "Label Assistant Reply",
desc: "Provide labels for a prompt.",
category: TaskCategory.Label,
pathname: "/label/label_assistant_reply",
type: "label_assistant_reply",
},
];
+62 -25
View File
@@ -1,4 +1,18 @@
import { Table, TableCaption, TableContainer, Tbody, Td, Th, Thead, Tr } from "@chakra-ui/react";
import {
Button,
Flex,
Spacer,
Stack,
Table,
TableCaption,
TableContainer,
Tbody,
Td,
Th,
Thead,
Tr,
} from "@chakra-ui/react";
import Link from "next/link";
import { useState } from "react";
import fetcher from "src/lib/fetcher";
import useSWR from "swr";
@@ -7,37 +21,60 @@ import useSWR from "swr";
* Fetches users from the users api route and then presents them in a simple Chakra table.
*/
const UsersCell = () => {
// Fetch and save the users.
const [pageIndex, setPageIndex] = useState(0);
const [users, setUsers] = useState([]);
const { isLoading } = useSWR("/api/admin/users", fetcher, {
// Fetch and save the users.
// This follows useSWR's recommendation for simple pagination:
// https://swr.vercel.app/docs/pagination#when-to-use-useswr
useSWR(`/api/admin/users?pageIndex=${pageIndex}`, fetcher, {
onSuccess: setUsers,
});
const toPreviousPage = () => {
setPageIndex(Math.max(0, pageIndex - 1));
};
const toNextPage = () => {
setPageIndex(pageIndex + 1);
};
// Present users in a naive table.
return (
<TableContainer>
<Table variant="simple">
<TableCaption>Users</TableCaption>
<Thead>
<Tr>
<Th>Id</Th>
<Th>Email</Th>
<Th>Name</Th>
<Th>Role</Th>
</Tr>
</Thead>
<Tbody>
{users.map((user, index) => (
<Tr key={index}>
<Td>{user.id}</Td>
<Td>{user.email}</Td>
<Td>{user.name}</Td>
<Td>{user.role}</Td>
<Stack>
<Flex p="2">
<Button onClick={toPreviousPage}>Previous</Button>
<Spacer />
<Button onClick={toNextPage}>Next</Button>
</Flex>
<TableContainer>
<Table variant="simple">
<TableCaption>Users</TableCaption>
<Thead>
<Tr>
<Th>Id</Th>
<Th>Email</Th>
<Th>Name</Th>
<Th>Role</Th>
<Th>Update</Th>
</Tr>
))}
</Tbody>
</Table>
</TableContainer>
</Thead>
<Tbody>
{users.map((user, index) => (
<Tr key={index}>
<Td>{user.id}</Td>
<Td>{user.email}</Td>
<Td>{user.name}</Td>
<Td>{user.role}</Td>
<Td>
<Link href={`/admin/manage_user/${user.id}`}>Manage</Link>
</Td>
</Tr>
))}
</Tbody>
</Table>
</TableContainer>
</Stack>
);
};
+3
View File
@@ -0,0 +1,3 @@
const flags = [{ name: "flagTest", isActive: false }];
export default flags;
@@ -0,0 +1,9 @@
import { useGenericTaskAPI } from "../useGenericTaskAPI";
interface CreateInitialPromptTask {
id: string;
type: "initial_prompt";
hint: string;
}
export const useCreateInitialPrompt = () => useGenericTaskAPI<CreateInitialPromptTask>("initial_prompt");
@@ -0,0 +1,24 @@
import { useGenericTaskAPI } from "../useGenericTaskAPI";
interface BaseCreateReplyTask {
id: string;
conversation: {
messages: Array<{
text: string;
is_assistant: boolean;
message_id: string;
}>;
};
}
export interface CreateAssistantReplyTask extends BaseCreateReplyTask {
type: "assistant_reply";
}
export interface CreatePrompterReplyTask extends BaseCreateReplyTask {
type: "prompter_reply";
}
export const useCreateAssistantReply = () => useGenericTaskAPI<CreateAssistantReplyTask>("assistant_reply");
export const useCreatePrompterReply = () => useGenericTaskAPI<CreatePrompterReplyTask>("prompter_reply");
@@ -0,0 +1,9 @@
import { useGenericTaskAPI } from "../useGenericTaskAPI";
interface RankInitialPromptsTask {
id: string;
type: "rank_initial_prompts";
prompts: string[];
}
export const useRankInitialPromptsTask = () => useGenericTaskAPI<RankInitialPromptsTask>("rank_initial_prompts");
@@ -0,0 +1,25 @@
import { useGenericTaskAPI } from "../useGenericTaskAPI";
interface BaseRankRepliesTask {
id: string;
replies: string[];
conversation: {
messages: Array<{
text: string;
is_assistant: boolean;
message_id: string;
}>;
};
}
interface RankAssistantRepliesTask extends BaseRankRepliesTask {
type: "rank_assistant_replies";
}
interface RankPrompterRepliesTask extends BaseRankRepliesTask {
type: "rank_prompter_replies";
}
export const useRankAssistantRepliesTask = () => useGenericTaskAPI<RankAssistantRepliesTask>("rank_assistant_replies");
export const useRankPrompterRepliesTask = () => useGenericTaskAPI<RankPrompterRepliesTask>("rank_prompter_replies");
@@ -0,0 +1,22 @@
import { TaskResponse } from "../useGenericTaskAPI";
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
export interface LabelAssistantReplyTask {
id: string;
type: LabelingTaskType.label_assistant_reply;
message_id: string;
valid_labels: string[];
reply: string;
conversation: {
messages: Array<{
text: string;
is_assistant: boolean;
message_id: string;
}>;
};
}
export type LabelAssistantReplyTaskResponse = TaskResponse<LabelAssistantReplyTask>;
export const useLabelAssistantReplyTask = () =>
useLabelingTask<LabelAssistantReplyTask>(LabelingTaskType.label_assistant_reply);
@@ -0,0 +1,15 @@
import { TaskResponse } from "../useGenericTaskAPI";
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
export interface LabelInitialPromptTask {
id: string;
type: LabelingTaskType.label_initial_prompt;
message_id: string;
valid_labels: string[];
prompt: string;
}
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
export const useLabelInitialPromptTask = () =>
useLabelingTask<LabelInitialPromptTask>(LabelingTaskType.label_initial_prompt);
@@ -0,0 +1,22 @@
import { TaskResponse } from "../useGenericTaskAPI";
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
export interface LabelPrompterReplyTask {
id: string;
type: LabelingTaskType.label_prompter_reply;
message_id: string;
valid_labels: string[];
reply: string;
conversation: {
messages: Array<{
text: string;
is_assistant: boolean;
message_id: string;
}>;
};
}
export type LabelPrompterReplyTaskResponse = TaskResponse<LabelPrompterReplyTask>;
export const useLabelPrompterReplyTask = () =>
useLabelingTask<LabelPrompterReplyTask>(LabelingTaskType.label_prompter_reply);
@@ -0,0 +1,20 @@
import { useGenericTaskAPI } from "../useGenericTaskAPI";
export const enum LabelingTaskType {
label_initial_prompt = "label_initial_prompt",
label_prompter_reply = "label_prompter_reply",
label_assistant_reply = "label_assistant_reply",
}
export const useLabelingTask = <TaskType>(endpoint: LabelingTaskType) => {
const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI<TaskType>(endpoint);
const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => {
console.assert(validLabels.length === labelWeights.length);
const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]]));
return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
};
return { tasks, isLoading, submit, reset, error };
};
@@ -0,0 +1,38 @@
import { useState } from "react";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
import useSWRMutation from "swr/mutation";
// TODO: type & centralize types for all tasks
export interface TaskResponse<TaskType> {
id: string;
userId: string;
task: TaskType;
}
export const useGenericTaskAPI = <TaskType,>(taskApiEndpoint: string) => {
type ConcreteTaskResponse = TaskResponse<TaskType>;
const [tasks, setTasks] = useState<ConcreteTaskResponse[]>([]);
const { isLoading, mutate, error } = useSWRImmutable<ConcreteTaskResponse>(
"/api/new_task/" + taskApiEndpoint,
fetcher,
{
onSuccess: (data) => setTasks([data]),
revalidateOnMount: true,
dedupingInterval: 500,
}
);
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (response) => {
const newTask: ConcreteTaskResponse = await response.json();
setTasks((oldTasks) => [...oldTasks, newTask]);
},
});
return { tasks, isLoading, trigger, error, reset: mutate };
};
+19
View File
@@ -0,0 +1,19 @@
import type { NextApiRequest, NextApiResponse } from "next";
import { getToken } from "next-auth/jwt";
/**
* Wraps any API Route handler and verifies that the user has the appropriate
* role before running the handler. Returns a 403 otherwise.
*/
const withRole = (role: string, handler: (arg0: NextApiRequest, arg1: NextApiResponse) => void) => {
return async (req: NextApiRequest, res: NextApiResponse) => {
const token = await getToken({ req });
if (!token || token.role !== role) {
res.status(403).end();
return;
}
return handler(req, res);
};
};
export default withRole;
+7 -1
View File
@@ -42,7 +42,7 @@ export class OasstApiClient {
} catch (e) {
throw new OasstError(errorText, 0, resp.status);
}
throw new OasstError(error.message, error.error_code, resp.status);
throw new OasstError(error.message ?? error, error.error_code, resp.status);
}
return await resp.json();
@@ -68,6 +68,12 @@ export class OasstApiClient {
});
}
async nackTask(taskId: string, reason: string): Promise<void> {
return this.post(`/api/v1/tasks/${taskId}/nack`, {
reason,
});
}
// TODO return a strongly typed Task?
// This method is used to record interaction with task while fetching next task.
// This is a raw Json type, so we can't use it to strongly type the task.
+2 -2
View File
@@ -1,8 +1,8 @@
export { default } from "next-auth/middleware";
/**
* Guards all pages under `/grading` and redirects them to the sign in page.
* Guards these pages and redirects them to the sign in page.
*/
export const config = {
matcher: ["/create/:path*", "/evaluate/:path*", "/account/:path*", "/dashboard"],
matcher: ["/create/:path*", "/evaluate/:path*", "/label/:path*", "/account/:path*", "/dashboard", "/admin/:path*"],
};
+7 -3
View File
@@ -3,7 +3,9 @@ import "focus-visible";
import type { AppProps } from "next/app";
import { SessionProvider } from "next-auth/react";
import { FlagsProvider } from "react-feature-flags";
import { getDefaultLayout, NextPageWithLayout } from "src/components/Layout";
import flags from "src/flags";
import { Chakra, getServerSideProps } from "../styles/Chakra";
@@ -16,9 +18,11 @@ function MyApp({ Component, pageProps: { session, cookies, ...pageProps } }: App
const page = getLayout(<Component {...pageProps} />);
return (
<Chakra cookies={cookies}>
<SessionProvider session={session}>{page}</SessionProvider>
</Chakra>
<FlagsProvider value={flags}>
<Chakra cookies={cookies}>
<SessionProvider session={session}>{page}</SessionProvider>
</Chakra>
</FlagsProvider>
);
}
export { getServerSideProps };
+3 -5
View File
@@ -2,7 +2,7 @@ import Head from "next/head";
import { useRouter } from "next/router";
import { useSession } from "next-auth/react";
import { useEffect } from "react";
import { getTransparentHeaderLayout } from "src/components/Layout";
import { getAdminLayout } from "src/components/Layout";
import UsersCell from "src/components/UsersCell";
/**
@@ -26,10 +26,8 @@ const AdminIndex = () => {
return;
}
router.push("/");
}, [session, status]);
}, [router, session, status]);
// Show the final page.
// TODO(#237): Display a component that fetches actual user data.
return (
<>
<Head>
@@ -44,6 +42,6 @@ const AdminIndex = () => {
);
};
AdminIndex.getLayout = getTransparentHeaderLayout;
AdminIndex.getLayout = getAdminLayout;
export default AdminIndex;
@@ -0,0 +1,131 @@
import { Button, Container, FormControl, FormLabel, Input, Select, useToast } from "@chakra-ui/react";
import { Field, Form, Formik } from "formik";
import Head from "next/head";
import { useRouter } from "next/router";
import { useSession } from "next-auth/react";
import { useEffect } from "react";
import { getAdminLayout } from "src/components/Layout";
import poster from "src/lib/poster";
import prisma from "src/lib/prismadb";
import useSWRMutation from "swr/mutation";
const ManageUser = ({ user }) => {
const toast = useToast();
const router = useRouter();
const { data: session, status } = useSession();
// Check when the user session is loaded and re-route if the user is not an
// admin. This follows the suggestion by NextJS for handling private pages:
// https://nextjs.org/docs/api-reference/next/router#usage
//
// All admin pages should use the same check and routing steps.
useEffect(() => {
if (status === "loading") {
return;
}
if (session?.user?.role === "admin") {
return;
}
router.push("/");
}, [router, session, status]);
// Trigger to let us update the user's role. Triggers a toast when complete.
const { trigger } = useSWRMutation("/api/admin/update_user", poster, {
onSuccess: () => {
toast({
title: "User Role Updated",
status: "success",
duration: 1000,
isClosable: true,
});
},
onError: () => {
toast({
title: "User Role update failed",
status: "error",
duration: 1000,
isClosable: true,
});
},
});
return (
<>
<Head>
<title>Manage Users - Open Assistant</title>
<meta
name="description"
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
/>
</Head>
<Container className="oa-basic-theme">
<Formik
initialValues={user}
onSubmit={(values) => {
trigger(values);
}}
>
<Form>
<Field name="id" type="hidden" />
<Field name="name">
{({ field }) => (
<FormControl>
<FormLabel>Username</FormLabel>
<Input {...field} isDisabled />
</FormControl>
)}
</Field>
<Field name="email">
{({ field }) => (
<FormControl>
<FormLabel>Email</FormLabel>
<Input {...field} isDisabled />
</FormControl>
)}
</Field>
<Field name="role">
{({ field }) => (
<FormControl>
<FormLabel>Role</FormLabel>
<Select {...field}>
<option value="banned">Banned</option>
<option value="general">General</option>
<option value="admin">Admin</option>
</Select>
</FormControl>
)}
</Field>
<Button mt={4} type="submit">
Update
</Button>
</Form>
</Formik>
</Container>
</>
);
};
/**
* Fetch the user's data on the server side when rendering.
*/
export async function getServerSideProps({ query }) {
const user = await prisma.user.findUnique({
where: { id: query.id },
select: {
id: true,
name: true,
email: true,
role: true,
},
});
return {
props: {
user,
},
};
}
ManageUser.getLayout = getAdminLayout;
export default ManageUser;
@@ -0,0 +1,22 @@
import withRole from "src/lib/auth";
import prisma from "src/lib/prismadb";
/**
* Update's the user's data in the database. Accessible only to admins.
*/
const handler = withRole("admin", async (req, res) => {
const { id, role } = JSON.parse(req.body);
await prisma.user.update({
where: {
id,
},
data: {
role,
},
});
res.status(200).end();
});
export default handler;
+16 -13
View File
@@ -1,31 +1,34 @@
import { getToken } from "next-auth/jwt";
import client from "src/lib/prismadb";
import withRole from "src/lib/auth";
import prisma from "src/lib/prismadb";
// The number of users to fetch in any request.
const PAGE_SIZE = 20;
/**
* Returns a list of user results from the database when the requesting user is
* a logged in admin.
*/
const handler = async (req, res) => {
const token = await getToken({ req });
// Return nothing if the user isn't registered or if the user isn't an admin.
if (!token || token.role !== "admin") {
res.status(403).end();
return;
}
const handler = withRole("admin", async (req, res) => {
// Figure out the pagination index and skip that number of users.
//
// Note: with Prisma this isn't the most efficient but it's the only possible
// option with cuid based User IDs.
const { pageIndex } = req.query;
const skip = parseInt(pageIndex as string) * PAGE_SIZE || 0;
// Fetch 20 users.
const users = await client.user.findMany({
const users = await prisma.user.findMany({
select: {
id: true,
role: true,
name: true,
email: true,
},
take: 20,
skip,
take: PAGE_SIZE,
});
res.status(200).json(users);
};
});
export default handler;
@@ -36,9 +36,6 @@ const handler = async (req, res) => {
},
});
// Update the backend with our Task ID
await oasstApiClient.ackTask(task.id, registeredTask.id);
// Send the results to the client.
res.status(200).json(registeredTask);
};
+29
View File
@@ -0,0 +1,29 @@
import { Prisma } from "@prisma/client";
import { getToken } from "next-auth/jwt";
import { oasstApiClient } from "src/lib/oasst_api_client";
const handler = async (req, res) => {
const token = await getToken({ req });
// Return nothing if the user isn't registered.
if (!token) {
res.status(401).end();
return;
}
// Parse out the local task ID and the interaction contents.
const { id: frontendId, reason } = await JSON.parse(req.body);
const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } });
const task = registeredTask.task as Prisma.JsonObject;
const id = task.id as string;
// Update the backend with the rejection
await oasstApiClient.nackTask(id, reason);
// Send the results to the client.
res.status(200).json({});
};
export default handler;
+41
View File
@@ -0,0 +1,41 @@
import { getToken } from "next-auth/jwt";
import prisma from "src/lib/prismadb";
/**
* Sets the Label in the Backend.
*
*/
const handler = async (req, res) => {
const token = await getToken({ req });
// Return nothing if the user isn't registered.
if (!token) {
res.status(401).end();
return;
}
// Parse out the local message_id, task ID and the interaction contents.
const { message_id, post_id, label_map, text } = await JSON.parse(req.body);
const interactionRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/text_labels`, {
method: "POST",
headers: {
"X-API-Key": process.env.FASTAPI_KEY,
"Content-Type": "application/json",
},
body: JSON.stringify({
type: "text_labels",
message_id: message_id,
labels: label_map,
text: text,
user: {
id: token.sub,
display_name: token.name || token.email,
auth_method: "local",
},
}),
});
res.status(interactionRes.status).end();
};
export default handler;
+20 -6
View File
@@ -1,3 +1,4 @@
import { Prisma } from "@prisma/client";
import { getToken } from "next-auth/jwt";
import { oasstApiClient } from "src/lib/oasst_api_client";
import prisma from "src/lib/prismadb";
@@ -6,9 +7,11 @@ import prisma from "src/lib/prismadb";
* Stores the task interaction with the Task Backend and then returns the next task generated.
*
* This implicity does a few things:
* 1) Stores the answer with the Task Backend.
* 2) Records the new task in our local database.
* 3) Returns the newly created task to the client.
* 1) Records the users answer in our local database.
* 2) Accepts the task.
* 3) Sends the users answer to the Task Backend.
* 4) Records the new task in our local database.
* 5) Returns the newly created task to the client.
*/
const handler = async (req, res) => {
const token = await getToken({ req });
@@ -20,7 +23,13 @@ const handler = async (req, res) => {
}
// Parse out the local task ID and the interaction contents.
const { id, content, update_type } = await JSON.parse(req.body);
const { id: frontendId, content, update_type } = await JSON.parse(req.body);
// Accept the task so that we can complete it, this will probably go away soon.
const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } });
const task = registeredTask.task as Prisma.JsonObject;
const id = task.id as string;
await oasstApiClient.ackTask(id, registeredTask.id);
// Log the interaction locally to create our user_post_id needed by the Task
// Backend.
@@ -29,13 +38,18 @@ const handler = async (req, res) => {
content,
task: {
connect: {
id,
id: frontendId,
},
},
},
});
const newTask = await oasstApiClient.interactTask(update_type, id, interaction.id, content, token);
let newTask;
try {
newTask = await oasstApiClient.interactTask(update_type, frontendId, interaction.id, content, token);
} catch (err) {
return res.status(500).json(err);
}
// Stores the new task with our database.
const newRegisteredTask = await prisma.registeredTask.create({
+48 -1
View File
@@ -2,17 +2,59 @@ import { Button, Input, Stack } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import Link from "next/link";
import { useRouter } from "next/router";
import { getCsrfToken, getProviders, signIn } from "next-auth/react";
import React, { useRef } from "react";
import React, { useEffect, useRef, useState } from "react";
import { FaBug, FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
import { AuthLayout } from "src/components/AuthLayout";
import { Footer } from "src/components/Footer";
import { Header } from "src/components/Header";
export type SignInErrorTypes =
| "Signin"
| "OAuthSignin"
| "OAuthCallback"
| "OAuthCreateAccount"
| "EmailCreateAccount"
| "Callback"
| "OAuthAccountNotLinked"
| "EmailSignin"
| "CredentialsSignin"
| "SessionRequired"
| "default";
const errorMessages: Record<SignInErrorTypes, string> = {
Signin: "Try signing in with a different account.",
OAuthSignin: "Try signing in with a different account.",
OAuthCallback: "Try signing in with the same account you used originally.",
OAuthCreateAccount: "Try signing in with a different account.",
EmailCreateAccount: "Try signing in with a different account.",
Callback: "Try signing in with a different account.",
OAuthAccountNotLinked: "To confirm your identity, sign in with the same account you used originally.",
EmailSignin: "The e-mail could not be sent.",
CredentialsSignin: "Sign in failed. Check the details you provided are correct.",
SessionRequired: "Please sign in to access this page.",
default: "Unable to sign in.",
};
// eslint-disable-next-line @typescript-eslint/no-unused-vars
function Signin({ csrfToken, providers }) {
const router = useRouter();
const { discord, email, github, credentials } = providers;
const emailEl = useRef(null);
const [error, setError] = useState("");
useEffect(() => {
const err = router?.query?.error;
if (err) {
if (typeof err === "string") {
setError(errorMessages[err]);
} else {
setError(errorMessages[err[0]]);
}
}
}, [router]);
const signinWithEmail = (ev: React.FormEvent) => {
ev.preventDefault();
signIn(email.id, { callbackUrl: "/dashboard", email: emailEl.current.value });
@@ -110,6 +152,11 @@ function Signin({ csrfToken, providers }) {
</Link>
.
</div>
{error && (
<div className="text-center mt-8">
<p className="text-orange-600">Error: {error}</p>
</div>
)}
</AuthLayout>
</div>
);
+9 -3
View File
@@ -1,17 +1,23 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { getCsrfToken, getProviders } from "next-auth/react";
import { AuthLayout } from "src/components/AuthLayout";
export default function Verify() {
const { colorMode } = useColorMode();
const bgColorClass = colorMode === "light" ? "bg-gray-50" : "bg-chakra-gray-900";
return (
<>
<Head>
<title>Sign Up - Open Assistant</title>
<meta name="Sign Up" content="Sign up to access Open Assistant" />
</Head>
<AuthLayout>
<h1 className="text-lg">A sign-in link has been sent to your email address.</h1>
</AuthLayout>
<div className={`flex h-full justify-center items-center ${bgColorClass}`}>
<div className={bgColorClass}>
<h1 className="text-lg">A sign-in link has been sent to your email address.</h1>
</div>
</div>
</>
);
}
+4 -27
View File
@@ -1,35 +1,12 @@
import { Container } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useEffect, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
import useSWRMutation from "swr/mutation";
import { useCreateAssistantReply } from "src/hooks/tasks/create/useCreateReply";
const AssistantReply = () => {
const [tasks, setTasks] = useState([]);
const { isLoading, mutate } = useSWRImmutable("/api/new_task/assistant_reply ", fetcher, {
onSuccess: (data) => {
setTasks([data]);
},
});
useEffect(() => {
if (tasks.length == 0) {
mutate();
}
}, [tasks]);
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
setTasks((oldTasks) => [...oldTasks, newTask]);
},
});
const { tasks, isLoading, reset, trigger } = useCreateAssistantReply();
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
@@ -38,7 +15,7 @@ const AssistantReply = () => {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length == 0) {
if (tasks.length === 0) {
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
}
@@ -48,7 +25,7 @@ const AssistantReply = () => {
<title>Reply as Assistant</title>
<meta name="description" content="Reply as Assistant." />
</Head>
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
</>
);
};
+4 -27
View File
@@ -1,35 +1,12 @@
import { Container } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useEffect, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
import useSWRMutation from "swr/mutation";
import { useCreateInitialPrompt } from "src/hooks/tasks/create/useCreateInitialPrompt";
const InitialPrompt = () => {
const [tasks, setTasks] = useState([]);
const { isLoading, mutate } = useSWRImmutable("/api/new_task/initial_prompt ", fetcher, {
onSuccess: (data) => {
setTasks([data]);
},
});
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
setTasks((oldTasks) => [...oldTasks, newTask]);
},
});
useEffect(() => {
if (tasks.length == 0) {
mutate();
}
}, [tasks]);
const { tasks, isLoading, reset, trigger } = useCreateInitialPrompt();
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
@@ -38,7 +15,7 @@ const InitialPrompt = () => {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length == 0) {
if (tasks.length === 0) {
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
}
@@ -48,7 +25,7 @@ const InitialPrompt = () => {
<title>Reply as Assistant</title>
<meta name="description" content="Reply as Assistant." />
</Head>
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
</>
);
};
+6 -1
View File
@@ -87,7 +87,12 @@ const SummarizeStory = () => {
</>
</TwoColumnsWithCards>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
<TaskControls
tasks={tasks}
onSubmitResponse={submitResponse}
onSkipTask={fetchNextTask}
onNextTask={fetchNextTask}
/>
</main>
</div>
);
+6 -34
View File
@@ -1,34 +1,12 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useEffect, useState } from "react";
import { Container } from "src/components/Container";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
import useSWRMutation from "swr/mutation";
import { useCreatePrompterReply } from "src/hooks/tasks/create/useCreateReply";
const UserReply = () => {
const [tasks, setTasks] = useState([]);
const { isLoading, mutate } = useSWRImmutable("/api/new_task/prompter_reply", fetcher, {
onSuccess: (data) => {
setTasks([data]);
},
});
useEffect(() => {
if (tasks.length == 0) {
mutate();
}
}, [tasks]);
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
setTasks((oldTasks) => [...oldTasks, newTask]);
},
});
const { tasks, isLoading, reset, trigger } = useCreatePrompterReply();
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
@@ -37,14 +15,8 @@ const UserReply = () => {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length == 0) {
return (
<div className={`p-12 ${mainBgClasses}`}>
<div className="flex h-full">
<div className="text-xl font-bold mx-auto my-auto">No tasks found...</div>
</div>
</div>
);
if (tasks.length === 0) {
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
}
return (
@@ -53,7 +25,7 @@ const UserReply = () => {
<title>Reply as Assistant</title>
<meta name="description" content="Reply as Assistant." />
</Head>
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
</>
);
};
@@ -1,34 +1,12 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useEffect, useState } from "react";
import { Container } from "src/components/Container";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
import useSWRMutation from "swr/mutation";
import { useRankAssistantRepliesTask } from "src/hooks/tasks/evaluate/useRankReplies";
const RankAssistantReplies = () => {
const [tasks, setTasks] = useState([]);
const { isLoading, mutate } = useSWRImmutable("/api/new_task/rank_assistant_replies", fetcher, {
onSuccess: (data) => {
setTasks([data]);
},
});
useEffect(() => {
if (tasks.length == 0) {
mutate();
}
}, [tasks]);
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
setTasks((oldTasks) => [...oldTasks, newTask]);
},
});
const { tasks, isLoading, reset, trigger } = useRankAssistantRepliesTask();
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
@@ -37,14 +15,8 @@ const RankAssistantReplies = () => {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length == 0) {
return (
<div className={`p-12 ${mainBgClasses}`}>
<div className="flex h-full">
<div className="text-xl font-bold mx-auto my-auto">No tasks found...</div>
</div>
</div>
);
if (tasks.length === 0) {
return <Container className="p-6 text-center">No tasks found...</Container>;
}
return (
@@ -53,7 +25,7 @@ const RankAssistantReplies = () => {
<title>Rank Assistant Replies</title>
<meta name="description" content="Rank Assistant Replies." />
</Head>
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
</>
);
};
@@ -1,34 +1,12 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useEffect, useState } from "react";
import { Container } from "src/components/Container";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
import useSWRMutation from "swr/mutation";
import { useRankInitialPromptsTask } from "src/hooks/tasks/evaluate/useRankInitialPrompts";
const RankInitialPrompts = () => {
const [tasks, setTasks] = useState([]);
const { isLoading, mutate } = useSWRImmutable("/api/new_task/rank_initial_prompts", fetcher, {
onSuccess: (data) => {
setTasks([data]);
},
});
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
setTasks((oldTasks) => [...oldTasks, newTask]);
},
});
useEffect(() => {
if (tasks.length == 0) {
mutate();
}
}, [tasks]);
const { tasks, isLoading, reset, trigger } = useRankInitialPromptsTask();
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
@@ -37,14 +15,8 @@ const RankInitialPrompts = () => {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length == 0) {
return (
<div className={`p-12 ${mainBgClasses}`}>
<div className="flex h-full">
<div className="text-xl font-bold mx-auto my-auto">No tasks found...</div>
</div>
</div>
);
if (tasks.length === 0) {
return <Container className="p-6 text-center">No tasks found...</Container>;
}
return (
@@ -53,7 +25,7 @@ const RankInitialPrompts = () => {
<title>Rank Initial Prompts</title>
<meta name="description" content="Rank initial prompts." />
</Head>
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
</>
);
};
@@ -1,34 +1,12 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { useEffect, useState } from "react";
import { Container } from "src/components/Container";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
import fetcher from "src/lib/fetcher";
import poster from "src/lib/poster";
import useSWRImmutable from "swr/immutable";
import useSWRMutation from "swr/mutation";
import { useRankPrompterRepliesTask } from "src/hooks/tasks/evaluate/useRankReplies";
const RankUserReplies = () => {
const [tasks, setTasks] = useState([]);
const { isLoading, mutate } = useSWRImmutable("/api/new_task/rank_prompter_replies", fetcher, {
onSuccess: (data) => {
setTasks([data]);
},
});
useEffect(() => {
if (tasks.length == 0) {
mutate();
}
}, [tasks]);
const { trigger } = useSWRMutation("/api/update_task", poster, {
onSuccess: async (data) => {
const newTask = await data.json();
setTasks((oldTasks) => [...oldTasks, newTask]);
},
});
const { tasks, isLoading, reset, trigger } = useRankPrompterRepliesTask();
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
@@ -37,14 +15,8 @@ const RankUserReplies = () => {
return <LoadingScreen text="Loading..." />;
}
if (tasks.length == 0) {
return (
<div className={`p-12 ${mainBgClasses}`}>
<div className="flex h-full">
<div className="text-xl font-bold mx-auto my-auto">No tasks found...</div>
</div>
</div>
);
if (tasks.length === 0) {
return <Container className="p-6 text-center">No tasks found...</Container>;
}
return (
@@ -53,7 +25,7 @@ const RankUserReplies = () => {
<title>Rank User Replies</title>
<meta name="description" content="Rank User Replies." />
</Head>
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
</>
);
};
+6 -1
View File
@@ -102,7 +102,12 @@ const RateSummary = () => {
</section>
</TwoColumnsWithCards>
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
<TaskControls
tasks={tasks}
onSubmitResponse={submitResponse}
onSkipTask={fetchNextTask}
onNextTask={fetchNextTask}
/>
</main>
</>
);
@@ -0,0 +1,47 @@
import { useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Message } from "src/components/Messages";
import { MessageTable } from "src/components/Messages/MessageTable";
import { TaskControls } from "src/components/Survey/TaskControls";
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
import {
LabelAssistantReplyTaskResponse,
useLabelAssistantReplyTask,
} from "src/hooks/tasks/labeling/useLabelAssistantReply";
const LabelAssistantReply = () => {
const [sliderValues, setSliderValues] = useState<number[]>([]);
const { tasks, isLoading, submit, reset } = useLabelAssistantReplyTask();
if (isLoading || tasks.length === 0) {
return <LoadingScreen />;
}
const task = tasks[0].task;
const messages: Message[] = [
...task.conversation.messages,
{ text: task.reply, is_assistant: true, message_id: task.message_id },
];
return (
<LabelTask
title="Label Assistant Reply"
desc="Given the following discussion, provide labels for the final prompt"
messages={<MessageTable messages={messages} />}
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
controls={
<TaskControls
tasks={tasks}
onSkipTask={() => reset()}
onNextTask={reset}
onSubmitResponse={({ id, task }: LabelAssistantReplyTaskResponse) =>
submit(id, task.message_id, task.reply, task.valid_labels, sliderValues)
}
/>
}
/>
);
};
export default LabelAssistantReply;
@@ -0,0 +1,42 @@
import { useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { MessageView } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
import {
LabelInitialPromptTaskResponse,
useLabelInitialPromptTask,
} from "src/hooks/tasks/labeling/useLabelInitialPrompt";
const LabelInitialPrompt = () => {
const [sliderValues, setSliderValues] = useState<number[]>([]);
const { tasks, isLoading, submit, reset } = useLabelInitialPromptTask();
if (isLoading || tasks.length === 0) {
return <LoadingScreen />;
}
const task = tasks[0].task;
return (
<LabelTask
title="Label Initial Prompt"
desc="Provide labels for the following prompt"
messages={<MessageView text={task.prompt} is_assistant message_id={task.message_id} />}
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
controls={
<TaskControls
tasks={tasks}
onSkipTask={() => reset()}
onNextTask={reset}
onSubmitResponse={({ id, task }: LabelInitialPromptTaskResponse) =>
submit(id, task.message_id, task.prompt, task.valid_labels, sliderValues)
}
/>
}
/>
);
};
export default LabelInitialPrompt;
@@ -0,0 +1,47 @@
import { useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Message } from "src/components/Messages";
import { MessageTable } from "src/components/Messages/MessageTable";
import { TaskControls } from "src/components/Survey/TaskControls";
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
import {
LabelPrompterReplyTaskResponse,
useLabelPrompterReplyTask,
} from "src/hooks/tasks/labeling/useLabelPrompterReply";
const LabelPrompterReply = () => {
const [sliderValues, setSliderValues] = useState<number[]>([]);
const { tasks, isLoading, submit, reset } = useLabelPrompterReplyTask();
if (isLoading || tasks.length === 0) {
return <LoadingScreen />;
}
const task = tasks[0].task;
const messages: Message[] = [
...task.conversation.messages,
{ text: task.reply, is_assistant: false, message_id: task.message_id },
];
return (
<LabelTask
title="Label Prompter Reply"
desc="Given the following discussion, provide labels for the final prompt"
messages={<MessageTable messages={messages} />}
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
controls={
<TaskControls
tasks={tasks}
onSkipTask={() => reset()}
onNextTask={reset}
onSubmitResponse={({ id, task }: LabelPrompterReplyTaskResponse) =>
submit(id, task.message_id, task.reply, task.valid_labels, sliderValues)
}
/>
}
/>
);
};
export default LabelPrompterReply;
+13 -13
View File
@@ -2,6 +2,7 @@ import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@cha
import Head from "next/head";
import { useEffect, useState } from "react";
import { getDashboardLayout } from "src/components/Layout";
import { Message } from "src/components/Messages";
import { MessageTable } from "src/components/Messages/MessageTable";
import fetcher from "src/lib/fetcher";
import useSWRImmutable from "swr/immutable";
@@ -10,29 +11,28 @@ const MessagesDashboard = () => {
const boxBgColor = useColorModeValue("white", "gray.700");
const boxAccentColor = useColorModeValue("gray.200", "gray.900");
const [messages, setMessages] = useState([]);
const [userMessages, setUserMessages] = useState([]);
const [messages, setMessages] = useState<Message[]>(null);
const [userMessages, setUserMessages] = useState<Message[]>(null);
const { isLoading: isLoadingAll, mutate: mutateAll } = useSWRImmutable("/api/messages", fetcher, {
onSuccess: (data) => {
setMessages(data);
},
onSuccess: setMessages,
});
const { isLoading: isLoadingUser, mutate: mutateUser } = useSWRImmutable(`/api/messages/user`, fetcher, {
onSuccess: (data) => {
setUserMessages(data);
},
onSuccess: setUserMessages,
});
const receivedMessages = !isLoadingAll && Array.isArray(messages);
const receivedUserMessages = !isLoadingUser && Array.isArray(userMessages);
useEffect(() => {
if (messages.length == 0) {
if (!receivedMessages) {
mutateAll();
}
if (userMessages.length == 0) {
if (!receivedUserMessages) {
mutateUser();
}
}, [messages, userMessages]);
}, [receivedMessages, mutateAll, receivedUserMessages, mutateUser]);
return (
<>
@@ -52,7 +52,7 @@ const MessagesDashboard = () => {
borderRadius="xl"
className="p-6 shadow-sm"
>
{isLoadingAll ? <CircularProgress isIndeterminate /> : <MessageTable messages={messages} />}
{receivedMessages ? <MessageTable messages={messages} /> : <CircularProgress isIndeterminate />}
</Box>
</Box>
<Box>
@@ -66,7 +66,7 @@ const MessagesDashboard = () => {
borderRadius="xl"
className="p-6 shadow-sm"
>
{isLoadingUser ? <CircularProgress isIndeterminate /> : <MessageTable messages={userMessages} />}
{receivedUserMessages ? <MessageTable messages={userMessages} /> : <CircularProgress isIndeterminate />}
</Box>
</Box>
</SimpleGrid>
+19
View File
@@ -0,0 +1,19 @@
import Head from "next/head";
import { TaskOption } from "src/components/Dashboard";
import { getDashboardLayout } from "src/components/Layout";
const AllTasks = () => {
return (
<>
<Head>
<title>All Tasks - Open Assistant</title>
<meta name="description" content="All tasks for Open Assistant." />
</Head>
<TaskOption />
</>
);
};
AllTasks.getLayout = (page) => getDashboardLayout(page);
export default AllTasks;