diff --git a/.github/workflows/deploy-docs-site.yaml b/.github/workflows/deploy-docs-site.yaml
index d3fee6cb..721a5b30 100644
--- a/.github/workflows/deploy-docs-site.yaml
+++ b/.github/workflows/deploy-docs-site.yaml
@@ -8,6 +8,9 @@ on:
- ".github/workflows/deploy-docs-site.yaml"
- "docs/**"
pull_request:
+ paths:
+ - ".github/workflows/deploy-docs-site.yaml"
+ - "docs/**"
jobs:
deploy:
diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml
index 0f747b45..0f82185f 100644
--- a/.github/workflows/pre-commit.yaml
+++ b/.github/workflows/pre-commit.yaml
@@ -16,3 +16,12 @@ jobs:
with:
python-version: "3.10"
- uses: pre-commit/action@v3.0.0
+ - name: Post PR comment on failure
+ if: failure() && github.event_name == 'pull_request'
+ uses: peter-evans/create-or-update-comment@v2
+ with:
+ issue-number: ${{ github.event.pull_request.number }}
+ body: |
+ :x: **pre-commit** failed.
+ Please run `pre-commit run --all-files` locally and commit the changes.
+ Find more information in the repository's CONTRIBUTING.md
diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index bb844a34..1bf0ac6a 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -29,6 +29,15 @@ jobs:
deploy-dev:
needs: [build-backend, build-web, build-bot]
runs-on: ubuntu-latest
+ env:
+ WEB_ADMIN_USERS: ${{ secrets.DEV_WEB_ADMIN_USERS }}
+ WEB_DISCORD_CLIENT_ID: ${{ secrets.DEV_WEB_DISCORD_CLIENT_ID }}
+ WEB_DISCORD_CLIENT_SECRET: ${{ secrets.DEV_WEB_DISCORD_CLIENT_SECRET }}
+ WEB_EMAIL_SERVER_HOST: ${{ secrets.DEV_WEB_EMAIL_SERVER_HOST }}
+ WEB_EMAIL_SERVER_PASSWORD: ${{ secrets.DEV_WEB_EMAIL_SERVER_PASSWORD }}
+ WEB_EMAIL_SERVER_PORT: ${{ secrets.DEV_WEB_EMAIL_SERVER_PORT }}
+ WEB_EMAIL_SERVER_USER: ${{ secrets.DEV_WEB_EMAIL_SERVER_USER }}
+ WEB_NEXTAUTH_SECRET: ${{ secrets.DEV_WEB_NEXTAUTH_SECRET }}
steps:
- name: Checkout
uses: actions/checkout@v2
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 608afe25..428f6a50 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -96,8 +96,10 @@ The website is built using Next.js and is in the `website` folder.
### Pre-commit
-Install `pre-commit` and run `pre-commit install` to install the pre-commit
-hooks.
+We are using `pre-commit` to enforce code style and formatting.
+
+Install `pre-commit` from [its website](https://pre-commit.com) and run
+`pre-commit install` to install the pre-commit hooks.
In case you haven't done this, have already committed, and CI is failing, you
can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
diff --git a/ansible/README.md b/ansible/README.md
new file mode 100644
index 00000000..2ab1943e
--- /dev/null
+++ b/ansible/README.md
@@ -0,0 +1,7 @@
+To test the ansible playbook on localhost run
+`ansible-playbook -i test.inventory.ini dev.yaml`.\
+In case you're missing the ansible docker depencency install it with `ansible-galaxy collection install community.docker`.\
+Point Redis Insights to the Redis database by visiting localhost:8001 in a
+browser and select "I already have a database" followed by "Connect to a Redis
+Database".\
+For host, port and name fill in `oasst-redis`, `6379` and `redis`.
diff --git a/ansible/dev.yaml b/ansible/dev.yaml
index 577abd68..e63a6673 100644
--- a/ansible/dev.yaml
+++ b/ansible/dev.yaml
@@ -10,6 +10,39 @@
state: present
driver: bridge
+ - name: Copy redis.conf to managed node
+ ansible.builtin.copy:
+ src: ./redis.conf
+ dest: ./redis.conf
+
+ - name: Set up Redis
+ community.docker.docker_container:
+ name: oasst-redis
+ image: redis
+ state: started
+ restart_policy: always
+ network_mode: oasst
+ ports:
+ - 6379:6379
+ healthcheck:
+ test: ["CMD-SHELL", "redis-cli ping | grep PONG"]
+ interval: 2s
+ timeout: 2s
+ retries: 10
+ command: redis-server /usr/local/etc/redis/redis.conf
+ volumes:
+ - "./redis.conf:/usr/local/etc/redis/redis.conf"
+
+ - name: Set up Redis Insights
+ community.docker.docker_container:
+ name: oasst-redis-insights
+ image: redislabs/redisinsight:latest
+ state: started
+ restart_policy: always
+ network_mode: oasst
+ ports:
+ - 8001:8001
+
- name: Create postgres containers
community.docker.docker_container:
name: "{{ item.name }}"
@@ -32,14 +65,6 @@
- name: oasst-postgres
- name: oasst-postgres-web
- - name: Set up maildev
- community.docker.docker_container:
- name: oasst-maildev
- image: maildev/maildev
- state: started
- restart_policy: always
- network_mode: oasst
-
- name: Run the oasst oasst-backend
community.docker.docker_container:
name: oasst-backend
@@ -51,6 +76,7 @@
network_mode: oasst
env:
POSTGRES_HOST: oasst-postgres
+ REDIS_HOST: oasst-redis
DEBUG_ALLOW_ANY_API_KEY: "true"
DEBUG_USE_SEED_DATA: "true"
MAX_WORKERS: "1"
@@ -68,15 +94,27 @@
restart_policy: always
network_mode: oasst
env:
- FASTAPI_URL: http://oasst-backend:8080
- FASTAPI_KEY: "123"
+ ADMIN_USERS: "{{ lookup('ansible.builtin.env', 'WEB_ADMIN_USERS') }}"
DATABASE_URL: postgres://postgres:postgres@oasst-postgres-web/postgres
- NEXTAUTH_SECRET: O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98=
- EMAIL_SERVER_HOST: oasst-maildev
- EMAIL_SERVER_PORT: "25"
- EMAIL_FROM: info@example.com
- NEXTAUTH_URL: http://web.dev.open-assistant.io/
DEBUG_LOGIN: "true"
+ DISCORD_CLIENT_ID:
+ "{{ lookup('ansible.builtin.env', 'WEB_DISCORD_CLIENT_ID') }}"
+ DISCORD_CLIENT_SECRET:
+ "{{ lookup('ansible.builtin.env', 'WEB_DISCORD_CLIENT_SECRET') }}"
+ EMAIL_FROM: open-assistent@laion.ai
+ EMAIL_SERVER_HOST:
+ "{{ lookup('ansible.builtin.env', 'WEB_EMAIL_SERVER_HOST') }}"
+ EMAIL_SERVER_PASSWORD:
+ "{{ lookup('ansible.builtin.env', 'WEB_EMAIL_SERVER_PASSWORD') }}"
+ EMAIL_SERVER_PORT:
+ "{{ lookup('ansible.builtin.env', 'WEB_EMAIL_SERVER_PORT') }}"
+ EMAIL_SERVER_USER:
+ "{{ lookup('ansible.builtin.env', 'WEB_EMAIL_SERVER_USER') }}"
+ FASTAPI_URL: http://oasst-backend:8080
+ FASTAPI_KEY: "1234"
+ NEXTAUTH_SECRET:
+ "{{ lookup('ansible.builtin.env', 'WEB_NEXTAUTH_SECRET') }}"
+ NEXTAUTH_URL: http://web.dev.open-assistant.io/
ports:
- 3000:3000
command: bash wait-for-postgres.sh node server.js
diff --git a/ansible/redis.conf b/ansible/redis.conf
new file mode 100644
index 00000000..58da1e05
--- /dev/null
+++ b/ansible/redis.conf
@@ -0,0 +1,2 @@
+maxmemory 100mb
+maxmemory-policy allkeys-lru
diff --git a/ansible/test.inventory.ini b/ansible/test.inventory.ini
new file mode 100644
index 00000000..bfe6d93f
--- /dev/null
+++ b/ansible/test.inventory.ini
@@ -0,0 +1,2 @@
+[test]
+dev ansible_connection=local
diff --git a/backend/alembic/versions/2023_01_07_1250-ba61fe17fb6e_added_frontend_type_to_api_client.py b/backend/alembic/versions/2023_01_07_1250-ba61fe17fb6e_added_frontend_type_to_api_client.py
index dbc89ebf..31de6791 100644
--- a/backend/alembic/versions/2023_01_07_1250-ba61fe17fb6e_added_frontend_type_to_api_client.py
+++ b/backend/alembic/versions/2023_01_07_1250-ba61fe17fb6e_added_frontend_type_to_api_client.py
@@ -20,4 +20,4 @@ def upgrade() -> None:
def downgrade() -> None:
- op.drop_column("api_client", "frontend_id")
+ op.drop_column("api_client", "frontend_type")
diff --git a/backend/main.py b/backend/main.py
index 1c93fc9f..b84a2d9e 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -14,7 +14,7 @@ from oasst_backend.api.deps import get_dummy_api_client
from oasst_backend.api.v1.api import api_router
from oasst_backend.config import settings
from oasst_backend.database import engine
-from oasst_backend.prompt_repository import PromptRepository
+from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from pydantic import BaseModel
@@ -110,7 +110,12 @@ if settings.DEBUG_USE_SEED_DATA:
with Session(engine) as db:
api_client = get_dummy_api_client(db)
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
- pr = PromptRepository(db=db, api_client=api_client, user=dummy_user)
+
+ ur = UserRepository(db=db, api_client=api_client)
+ tr = TaskRepository(db=db, api_client=api_client, client_user=dummy_user, user_repository=ur)
+ pr = PromptRepository(
+ db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
+ )
with open(settings.DEBUG_USE_SEED_DATA_PATH) as f:
dummy_messages_raw = json.load(f)
@@ -118,14 +123,14 @@ if settings.DEBUG_USE_SEED_DATA:
dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw]
for msg in dummy_messages:
- task = pr.fetch_task_by_frontend_message_id(msg.task_message_id)
+ task = tr.fetch_task_by_frontend_message_id(msg.task_message_id)
if task and not task.ack:
logger.warning("Deleting unacknowledged seed data task")
db.delete(task)
task = None
if not task:
if msg.parent_message_id is None:
- task = pr.store_task(
+ task = tr.store_task(
protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
)
else:
@@ -144,12 +149,12 @@ if settings.DEBUG_USE_SEED_DATA:
for cmsg in conversation_messages
]
)
- task = pr.store_task(
+ task = tr.store_task(
protocol_schema.AssistantReplyTask(conversation=conversation),
message_tree_id=parent_message.message_tree_id,
parent_message_id=parent_message.id,
)
- pr.bind_frontend_message_id(task.id, msg.task_message_id)
+ tr.bind_frontend_message_id(task.id, msg.task_message_id)
message = pr.store_text_reply(msg.text, msg.task_message_id, msg.user_message_id)
logger.info(
diff --git a/backend/oasst_backend/api/v1/frontend_messages.py b/backend/oasst_backend/api/v1/frontend_messages.py
index 420f0d1b..f149bebb 100644
--- a/backend/oasst_backend/api/v1/frontend_messages.py
+++ b/backend/oasst_backend/api/v1/frontend_messages.py
@@ -16,7 +16,7 @@ def get_message_by_frontend_id(
"""
Get a message by its frontend ID.
"""
- pr = PromptRepository(db, api_client, user=None)
+ pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
return utils.prepare_message(message)
@@ -29,7 +29,7 @@ def get_conv_by_frontend_id(
Get a conversation from the tree root and up to the message with given frontend ID.
"""
- pr = PromptRepository(db, api_client, user=None)
+ pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
messages = pr.fetch_message_conversation(message)
return utils.prepare_conversation(messages)
@@ -43,7 +43,7 @@ def get_tree_by_frontend_id(
Get all messages belonging to the same message tree.
Message is identified by its frontend ID.
"""
- pr = PromptRepository(db, api_client, user=None)
+ pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
tree = pr.fetch_message_tree(message.message_tree_id)
return utils.prepare_tree(tree, message.message_tree_id)
@@ -56,7 +56,7 @@ def get_children_by_frontend_id(
"""
Get all messages belonging to the same message tree.
"""
- pr = PromptRepository(db, api_client, user=None)
+ pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
messages = pr.fetch_message_children(message.id)
return utils.prepare_message_list(messages)
@@ -70,7 +70,7 @@ def get_descendants_by_frontend_id(
Get a subtree which starts with this message.
The message is identified by its frontend ID.
"""
- pr = PromptRepository(db, api_client, user=None)
+ pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
descendants = pr.fetch_message_descendants(message)
return utils.prepare_tree(descendants, message.id)
@@ -84,7 +84,7 @@ def get_longest_conv_by_frontend_id(
Get the longest conversation from the tree of the message.
The message is identified by its frontend ID.
"""
- pr = PromptRepository(db, api_client, user=None)
+ pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
conv = pr.fetch_longest_conversation(message.message_tree_id)
return utils.prepare_conversation(conv)
@@ -98,7 +98,7 @@ def get_max_children_by_frontend_id(
Get message with the most children from the tree of the provided message.
The message is identified by its frontend ID.
"""
- pr = PromptRepository(db, api_client, user=None)
+ pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
return utils.prepare_tree([message, *children], message.id)
diff --git a/backend/oasst_backend/api/v1/frontend_users.py b/backend/oasst_backend/api/v1/frontend_users.py
index 0a745462..8d56b7f9 100644
--- a/backend/oasst_backend/api/v1/frontend_users.py
+++ b/backend/oasst_backend/api/v1/frontend_users.py
@@ -29,7 +29,7 @@ def query_frontend_user_messages(
"""
Query frontend user messages.
"""
- pr = PromptRepository(db, api_client, user=None)
+ pr = PromptRepository(db, api_client)
messages = pr.query_messages(
username=username,
api_client_id=api_client_id,
@@ -47,6 +47,6 @@ def query_frontend_user_messages(
def mark_frontend_user_messages_deleted(
username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
- pr = PromptRepository(db, api_client, None)
+ pr = PromptRepository(db, api_client)
messages = pr.query_messages(username=username, api_client_id=api_client.id)
pr.mark_messages_deleted(messages)
diff --git a/backend/oasst_backend/api/v1/leaderboards.py b/backend/oasst_backend/api/v1/leaderboards.py
index 4202edad..46aea637 100644
--- a/backend/oasst_backend/api/v1/leaderboards.py
+++ b/backend/oasst_backend/api/v1/leaderboards.py
@@ -1,7 +1,8 @@
from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
-from oasst_backend.prompt_repository import PromptRepository
+from oasst_backend.user_repository import UserRepository
+from oasst_shared.schemas.protocol import LeaderboardStats
from sqlmodel import Session
router = APIRouter()
@@ -11,15 +12,15 @@ router = APIRouter()
def get_assistant_leaderboard(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
-):
- pr = PromptRepository(db, api_client, None)
- return pr.get_user_leaderboard(role="assistant")
+) -> LeaderboardStats:
+ ur = UserRepository(db, api_client)
+ return ur.get_user_leaderboard(role="assistant")
@router.get("/create/prompter")
def get_prompter_leaderboard(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
-):
- pr = PromptRepository(db, api_client, None)
- return pr.get_user_leaderboard(role="prompter")
+) -> LeaderboardStats:
+ ur = UserRepository(db, api_client)
+ return ur.get_user_leaderboard(role="prompter")
diff --git a/backend/oasst_backend/api/v1/messages.py b/backend/oasst_backend/api/v1/messages.py
index 7a2fd2e9..6229e20c 100644
--- a/backend/oasst_backend/api/v1/messages.py
+++ b/backend/oasst_backend/api/v1/messages.py
@@ -29,7 +29,7 @@ def query_messages(
"""
Query messages.
"""
- pr = PromptRepository(db, api_client, user=None)
+ pr = PromptRepository(db, api_client)
messages = pr.query_messages(
username=username,
api_client_id=api_client_id,
@@ -51,7 +51,7 @@ def get_message(
"""
Get a message by its internal ID.
"""
- pr = PromptRepository(db, api_client, user=None)
+ pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
return utils.prepare_message(message)
@@ -64,7 +64,7 @@ def get_conv(
Get a conversation from the tree root and up to the message with given internal ID.
"""
- pr = PromptRepository(db, api_client, user=None)
+ pr = PromptRepository(db, api_client)
messages = pr.fetch_message_conversation(message_id)
return utils.prepare_conversation(messages)
@@ -76,7 +76,7 @@ def get_tree(
"""
Get all messages belonging to the same message tree.
"""
- pr = PromptRepository(db, api_client, user=None)
+ pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
tree = pr.fetch_message_tree(message.message_tree_id)
return utils.prepare_tree(tree, message.message_tree_id)
@@ -89,7 +89,7 @@ def get_children(
"""
Get all messages belonging to the same message tree.
"""
- pr = PromptRepository(db, api_client, user=None)
+ pr = PromptRepository(db, api_client)
messages = pr.fetch_message_children(message_id)
return utils.prepare_message_list(messages)
@@ -101,7 +101,7 @@ def get_descendants(
"""
Get a subtree which starts with this message.
"""
- pr = PromptRepository(db, api_client, user=None)
+ pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
descendants = pr.fetch_message_descendants(message)
return utils.prepare_tree(descendants, message.id)
@@ -114,7 +114,7 @@ def get_longest_conv(
"""
Get the longest conversation from the tree of the message.
"""
- pr = PromptRepository(db, api_client, user=None)
+ pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
conv = pr.fetch_longest_conversation(message.message_tree_id)
return utils.prepare_conversation(conv)
@@ -127,7 +127,7 @@ def get_max_children(
"""
Get message with the most children from the tree of the provided message.
"""
- pr = PromptRepository(db, api_client, user=None)
+ pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
return utils.prepare_tree([message, *children], message.id)
@@ -137,5 +137,5 @@ def get_max_children(
def mark_message_deleted(
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
- pr = PromptRepository(db, api_client, None)
+ pr = PromptRepository(db, api_client)
pr.mark_messages_deleted(message_id)
diff --git a/backend/oasst_backend/api/v1/stats.py b/backend/oasst_backend/api/v1/stats.py
index a54aa07b..1aaffb1b 100644
--- a/backend/oasst_backend/api/v1/stats.py
+++ b/backend/oasst_backend/api/v1/stats.py
@@ -13,5 +13,5 @@ def get_message_stats(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
- pr = PromptRepository(db, api_client, None)
+ pr = PromptRepository(db, api_client)
return pr.get_stats()
diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py
index adfb2907..eb10dc00 100644
--- a/backend/oasst_backend/api/v1/tasks.py
+++ b/backend/oasst_backend/api/v1/tasks.py
@@ -7,7 +7,7 @@ from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.api.v1.utils import prepare_conversation
-from oasst_backend.prompt_repository import PromptRepository
+from oasst_backend.prompt_repository import PromptRepository, TaskRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
@@ -190,9 +190,9 @@ def request_task(
api_client = deps.api_auth(api_key, db)
try:
- pr = PromptRepository(db, api_client, request.user)
+ pr = PromptRepository(db, api_client, client_user=request.user)
task, message_tree_id, parent_message_id = generate_task(request, pr)
- pr.store_task(task, message_tree_id, parent_message_id, request.collective)
+ pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective)
except OasstError:
raise
@@ -217,11 +217,11 @@ def tasks_acknowledge(
api_client = deps.api_auth(api_key, db)
try:
- pr = PromptRepository(db, api_client, user=None)
+ pr = PromptRepository(db, api_client)
# here we store the message id in the database for the task
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
- pr.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)
+ pr.task_repository.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)
except OasstError:
raise
@@ -245,8 +245,8 @@ def tasks_acknowledge_failure(
try:
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
api_client = deps.api_auth(api_key, db)
- pr = PromptRepository(db, api_client, user=None)
- pr.acknowledge_task_failure(task_id)
+ pr = PromptRepository(db, api_client)
+ pr.task_repository.acknowledge_task_failure(task_id)
except (KeyError, RuntimeError):
logger.exception("Failed to not acknowledge task.")
raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED)
@@ -265,7 +265,7 @@ def tasks_interaction(
api_client = deps.api_auth(api_key, db)
try:
- pr = PromptRepository(db, api_client, user=interaction.user)
+ pr = PromptRepository(db, api_client, client_user=interaction.user)
match type(interaction):
case protocol_schema.TextReplyToMessage:
@@ -323,6 +323,6 @@ def close_collective_task(
api_key: APIKey = Depends(deps.get_api_key),
):
api_client = deps.api_auth(api_key, db)
- pr = PromptRepository(db, api_client, user=None)
- pr.close_task(close_task_request.message_id)
+ tr = TaskRepository(db, api_client)
+ tr.close_task(close_task_request.message_id)
return protocol_schema.TaskDone()
diff --git a/backend/oasst_backend/api/v1/text_labels.py b/backend/oasst_backend/api/v1/text_labels.py
index 03fd2cb4..c9afd88c 100644
--- a/backend/oasst_backend/api/v1/text_labels.py
+++ b/backend/oasst_backend/api/v1/text_labels.py
@@ -25,7 +25,7 @@ def label_text(
try:
logger.info(f"Labeling text {text_labels=}.")
- pr = PromptRepository(db, api_client, user=text_labels.user)
+ pr = PromptRepository(db, api_client, client_user=text_labels.user)
pr.store_text_labels(text_labels)
except Exception:
diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py
index 8d55bfec..5dda88eb 100644
--- a/backend/oasst_backend/api/v1/users.py
+++ b/backend/oasst_backend/api/v1/users.py
@@ -29,7 +29,7 @@ def query_user_messages(
"""
Query user messages.
"""
- pr = PromptRepository(db, api_client, user=None)
+ pr = PromptRepository(db, api_client)
messages = pr.query_messages(
user_id=user_id,
api_client_id=api_client_id,
@@ -48,6 +48,6 @@ def query_user_messages(
def mark_user_messages_deleted(
user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
- pr = PromptRepository(db, api_client, None)
+ pr = PromptRepository(db, api_client)
messages = pr.query_messages(user_id=user_id)
pr.mark_messages_deleted(messages)
diff --git a/backend/oasst_backend/models/message_tree_state.py b/backend/oasst_backend/models/message_tree_state.py
index 386595e9..97ad34eb 100644
--- a/backend/oasst_backend/models/message_tree_state.py
+++ b/backend/oasst_backend/models/message_tree_state.py
@@ -6,27 +6,56 @@ import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, Index, SQLModel
-# The types of States a message tree can have.
+class States(str, Enum):
+ """States of the Open-Assistant message tree state machine."""
+
+ INITIAL_PROMPT_REVIEW = "initial_prompt_review"
+ """In this state the message tree consists only of a single inital prompt root node.
+ Initial prompt labeling tasks will determine if the tree goes into `breeding_phase` or
+ `aborted_low_grade`."""
-class States(Enum):
- INITIAL = "initial"
BREEDING_PHASE = "breeding_phase"
+ """Assistant & prompter human demonstrations are collected. Concurrently labeling tasks
+ are handed out to check if the quality of the replies surpasses the minimum acceptable
+ quality.
+ When the required number of messages passing the initial labelling-quality check has been
+ collected the tree will enter `ranking_phase`. If too many poor-quality labelling responses
+ are received the tree can also enter the `aborted_low_grade` state."""
+
RANKING_PHASE = "ranking_phase"
+ """The tree has been successfully populated with the desired number of messages. Ranking
+ tasks are now handed out for all nodes with more than one child."""
+
READY_FOR_SCORING = "ready_for_scoring"
- CHILDREN_SCORED = "children_scored"
- FINAL = "final"
+ """Required ranking responses have been collected and the scoring algorithm can now
+ compute the aggergated ranking scores that will appear in the dataset."""
+
+ READY_FOR_EXPORT = "ready_for_export"
+ """The Scoring algorithm computed rankings scores for all childern. The message tree can be
+ exported as part of an Open-Assistant message tree dataset."""
+
+ SCORING_FAILED = "scoring_failed"
+ """An exception occured in the scoring algorithm."""
+
+ ABORTED_LOW_GRADE = "aborted_low_grade"
+ """The system received too many bad reviews and stopped handing out tasks for this message tree."""
+
+ HALTED_BY_MODERATOR = "halted_by_moderator"
+ """A moderator decided to manually halt the message tree construction process."""
VALID_STATES = (
- States.INITIAL,
+ States.INITIAL_PROMPT_REVIEW,
States.BREEDING_PHASE,
States.RANKING_PHASE,
States.READY_FOR_SCORING,
- States.CHILDREN_SCORED,
- States.FINAL,
+ States.READY_FOR_EXPORT,
+ States.ABORTED_LOW_GRADE,
)
+TERMINAL_STATES = (States.READY_FOR_EXPORT, States.ABORTED_LOW_GRADE, States.SCORING_FAILED, States.HALTED_BY_MODERATOR)
+
class MessageTreeState(SQLModel, table=True):
__tablename__ = "message_tree_state"
diff --git a/backend/oasst_backend/models/task.py b/backend/oasst_backend/models/task.py
index a980c1b5..d923be97 100644
--- a/backend/oasst_backend/models/task.py
+++ b/backend/oasst_backend/models/task.py
@@ -35,4 +35,4 @@ class Task(SQLModel, table=True):
@property
def expired(self) -> bool:
- return self.expiry_date is not None and datetime.utcnow() < self.expiry_date
+ return self.expiry_date is not None and datetime.utcnow() > self.expiry_date
diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py
index 7c7dd7b6..7446ec07 100644
--- a/backend/oasst_backend/prompt_repository.py
+++ b/backend/oasst_backend/prompt_repository.py
@@ -8,98 +8,39 @@ from uuid import UUID, uuid4
import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.journal_writer import JournalWriter
-from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User
+from oasst_backend.models import ApiClient, Message, MessageReaction, TextLabels, User
from oasst_backend.models.payload_column_type import PayloadContainer
+from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id
+from oasst_backend.user_repository import UserRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
-from oasst_shared.schemas.protocol import LeaderboardStats, SystemStats
+from oasst_shared.schemas.protocol import SystemStats
from sqlalchemy import update
from sqlmodel import Session, func
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
class PromptRepository:
- def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]):
+ def __init__(
+ self,
+ db: Session,
+ api_client: ApiClient,
+ client_user: Optional[protocol_schema.User] = None,
+ user_repository: Optional[UserRepository] = None,
+ task_repository: Optional[TaskRepository] = None,
+ ):
self.db = db
self.api_client = api_client
- self.user = self.lookup_user(user)
+ self.user_repository = user_repository or UserRepository(db, api_client)
+ self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
self.user_id = self.user.id if self.user else None
+ self.task_repository = task_repository or TaskRepository(
+ db, api_client, client_user, user_repository=self.user_repository
+ )
self.journal = JournalWriter(db, api_client, self.user)
- def lookup_user(self, client_user: protocol_schema.User) -> Optional[User]:
- if not client_user:
- return None
- user: User = (
- self.db.query(User)
- .filter(
- User.api_client_id == self.api_client.id,
- User.username == client_user.id,
- User.auth_method == client_user.auth_method,
- )
- .first()
- )
- if user is None:
- # user is unknown, create new record
- user = User(
- username=client_user.id,
- display_name=client_user.display_name,
- api_client_id=self.api_client.id,
- auth_method=client_user.auth_method,
- )
- self.db.add(user)
- self.db.commit()
- self.db.refresh(user)
- elif client_user.display_name and client_user.display_name != user.display_name:
- # we found the user but the display name changed
- user.display_name = client_user.display_name
- self.db.add(user)
- self.db.commit()
- return user
-
- def validate_frontend_message_id(self, message_id: str) -> None:
- # TODO: Should it be replaced with fastapi/pydantic validation?
- if not isinstance(message_id, str):
- raise OasstError(
- f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID
- )
- if not message_id:
- raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
-
- def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
- self.validate_frontend_message_id(frontend_message_id)
-
- # find task
- task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
- if task is None:
- raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
- if task.expired:
- raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
- if task.done or task.ack is not None:
- raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
-
- task.frontend_message_id = frontend_message_id
- task.ack = True
- # ToDo: check race-condition, transaction
- self.db.add(task)
- self.db.commit()
-
- def acknowledge_task_failure(self, task_id):
- # find task
- task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
- if task is None:
- raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
- if task.expired:
- raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
- if task.done or task.ack is not None:
- raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
-
- task.ack = False
- # ToDo: check race-condition, transaction
- self.db.add(task)
- self.db.commit()
-
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
- self.validate_frontend_message_id(frontend_message_id)
+ validate_frontend_message_id(frontend_message_id)
message: Message = (
self.db.query(Message)
.filter(Message.api_client_id == self.api_client.id, Message.frontend_message_id == frontend_message_id)
@@ -113,20 +54,48 @@ class PromptRepository:
)
return message
- def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
- self.validate_frontend_message_id(message_id)
- task = (
- self.db.query(Task)
- .filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id)
- .one_or_none()
+ def insert_message(
+ self,
+ *,
+ message_id: UUID,
+ frontend_message_id: str,
+ parent_id: UUID,
+ message_tree_id: UUID,
+ task_id: UUID,
+ role: str,
+ payload: db_payload.MessagePayload,
+ payload_type: str = None,
+ depth: int = 0,
+ ) -> Message:
+ if payload_type is None:
+ if payload is None:
+ payload_type = "null"
+ else:
+ payload_type = type(payload).__name__
+
+ message = Message(
+ id=message_id,
+ parent_id=parent_id,
+ message_tree_id=message_tree_id,
+ task_id=task_id,
+ user_id=self.user_id,
+ role=role,
+ frontend_message_id=frontend_message_id,
+ api_client_id=self.api_client.id,
+ payload_type=payload_type,
+ payload=PayloadContainer(payload=payload),
+ depth=depth,
)
- return task
+ self.db.add(message)
+ self.db.commit()
+ self.db.refresh(message)
+ return message
def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message:
- self.validate_frontend_message_id(frontend_message_id)
- self.validate_frontend_message_id(user_frontend_message_id)
+ validate_frontend_message_id(frontend_message_id)
+ validate_frontend_message_id(user_frontend_message_id)
- task = self.fetch_task_by_frontend_message_id(frontend_message_id)
+ task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id)
if task is None:
raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
@@ -174,7 +143,7 @@ class PromptRepository:
def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction:
message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True)
- task = self.fetch_task_by_frontend_message_id(rating.message_id)
+ task = self.task_repository.fetch_task_by_frontend_message_id(rating.message_id)
task_payload: db_payload.RateSummaryPayload = task.payload.payload
if type(task_payload) != db_payload.RateSummaryPayload:
raise OasstError(
@@ -201,7 +170,7 @@ class PromptRepository:
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> MessageReaction:
# fetch task
- task = self.fetch_task_by_frontend_message_id(ranking.message_id)
+ task = self.task_repository.fetch_task_by_frontend_message_id(ranking.message_id)
if not task.collective:
task.done = True
self.db.add(task)
@@ -255,142 +224,6 @@ class PromptRepository:
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
)
- def store_task(
- self,
- task: protocol_schema.Task,
- message_tree_id: UUID = None,
- parent_message_id: UUID = None,
- collective: bool = False,
- ) -> Task:
- payload: db_payload.TaskPayload
- match type(task):
- case protocol_schema.SummarizeStoryTask:
- payload = db_payload.SummarizationStoryPayload(story=task.story)
-
- case protocol_schema.RateSummaryTask:
- payload = db_payload.RateSummaryPayload(
- full_text=task.full_text, summary=task.summary, scale=task.scale
- )
-
- case protocol_schema.InitialPromptTask:
- payload = db_payload.InitialPromptPayload(hint=task.hint)
-
- case protocol_schema.PrompterReplyTask:
- payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
-
- case protocol_schema.AssistantReplyTask:
- payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
-
- case protocol_schema.RankInitialPromptsTask:
- payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts)
-
- case protocol_schema.RankPrompterRepliesTask:
- payload = db_payload.RankPrompterRepliesPayload(
- type=task.type, conversation=task.conversation, replies=task.replies
- )
-
- case protocol_schema.RankAssistantRepliesTask:
- payload = db_payload.RankAssistantRepliesPayload(
- type=task.type, conversation=task.conversation, replies=task.replies
- )
-
- case protocol_schema.LabelInitialPromptTask:
- payload = db_payload.LabelInitialPromptPayload(
- type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
- )
-
- case protocol_schema.LabelPrompterReplyTask:
- payload = db_payload.LabelPrompterReplyPayload(
- type=task.type,
- message_id=task.message_id,
- conversation=task.conversation,
- reply=task.reply,
- valid_labels=task.valid_labels,
- )
-
- case protocol_schema.LabelAssistantReplyTask:
- payload = db_payload.LabelAssistantReplyPayload(
- type=task.type,
- message_id=task.message_id,
- conversation=task.conversation,
- reply=task.reply,
- valid_labels=task.valid_labels,
- )
-
- case _:
- raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
-
- task = self.insert_task(
- payload=payload,
- id=task.id,
- message_tree_id=message_tree_id,
- parent_message_id=parent_message_id,
- collective=collective,
- )
- assert task.id == task.id
- return task
-
- def insert_task(
- self,
- payload: db_payload.TaskPayload,
- id: UUID = None,
- message_tree_id: UUID = None,
- parent_message_id: UUID = None,
- collective: bool = False,
- ) -> Task:
- c = PayloadContainer(payload=payload)
- task = Task(
- id=id,
- user_id=self.user_id,
- payload_type=type(payload).__name__,
- payload=c,
- api_client_id=self.api_client.id,
- message_tree_id=message_tree_id,
- parent_message_id=parent_message_id,
- collective=collective,
- )
- self.db.add(task)
- self.db.commit()
- self.db.refresh(task)
- return task
-
- def insert_message(
- self,
- *,
- message_id: UUID,
- frontend_message_id: str,
- parent_id: UUID,
- message_tree_id: UUID,
- task_id: UUID,
- role: str,
- payload: db_payload.MessagePayload,
- payload_type: str = None,
- depth: int = 0,
- ) -> Message:
- if payload_type is None:
- if payload is None:
- payload_type = "null"
- else:
- payload_type = type(payload).__name__
-
- message = Message(
- id=message_id,
- parent_id=parent_id,
- message_tree_id=message_tree_id,
- task_id=task_id,
- user_id=self.user_id,
- role=role,
- frontend_message_id=frontend_message_id,
- api_client_id=self.api_client.id,
- payload_type=payload_type,
- payload=PayloadContainer(payload=payload),
- depth=depth,
- )
- self.db.add(message)
- self.db.commit()
- self.db.refresh(message)
- return message
-
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
if self.user_id is None:
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
@@ -515,28 +348,6 @@ class PromptRepository:
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
return message
- def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
- """
- Mark task as done. No further messages will be accepted for this task.
- """
- self.validate_frontend_message_id(frontend_message_id)
- task = self.fetch_task_by_frontend_message_id(frontend_message_id)
-
- if not task:
- raise OasstError(
- f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND
- )
- if task.expired:
- raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
- if not allow_personal_tasks and not task.collective:
- raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
- if task.done:
- raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)
-
- task.done = True
- self.db.add(task)
- self.db.commit()
-
@staticmethod
def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]:
"""
@@ -728,24 +539,3 @@ class PromptRepository:
deleted=result.get(True, 0),
message_trees=result.get(None, 0),
)
-
- def get_user_leaderboard(self, role: str) -> LeaderboardStats:
- """
- Get leaderboard stats for Messages created,
- separate leaderboard for prompts & assistants
-
- """
- query = (
- self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id))
- .join(User, User.id == Message.user_id, isouter=True)
- .filter(Message.deleted is not True, Message.role == role)
- .group_by(Message.user_id, User.username, User.display_name)
- .order_by(func.count(Message.user_id).desc())
- )
-
- result = [
- {"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]}
- for i, j in enumerate(query.all(), start=1)
- ]
-
- return LeaderboardStats(leaderboard=result)
diff --git a/backend/oasst_backend/task_repository.py b/backend/oasst_backend/task_repository.py
new file mode 100644
index 00000000..15484d66
--- /dev/null
+++ b/backend/oasst_backend/task_repository.py
@@ -0,0 +1,199 @@
+from typing import Optional
+from uuid import UUID
+
+import oasst_backend.models.db_payload as db_payload
+from oasst_backend.models import ApiClient, Task
+from oasst_backend.models.payload_column_type import PayloadContainer
+from oasst_backend.user_repository import UserRepository
+from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
+from oasst_shared.schemas import protocol as protocol_schema
+from sqlmodel import Session
+from starlette.status import HTTP_404_NOT_FOUND
+
+
+def validate_frontend_message_id(message_id: str) -> None:
+ # TODO: Should it be replaced with fastapi/pydantic validation?
+ if not isinstance(message_id, str):
+ raise OasstError(
+ f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID
+ )
+ if not message_id:
+ raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
+
+
+class TaskRepository:
+ def __init__(
+ self,
+ db: Session,
+ api_client: ApiClient,
+ client_user: Optional[protocol_schema.User],
+ user_repository: UserRepository,
+ ):
+ self.db = db
+ self.api_client = api_client
+ self.user_repository = user_repository
+ self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
+ self.user_id = self.user.id if self.user else None
+
+ def store_task(
+ self,
+ task: protocol_schema.Task,
+ message_tree_id: UUID = None,
+ parent_message_id: UUID = None,
+ collective: bool = False,
+ ) -> Task:
+ payload: db_payload.TaskPayload
+ match type(task):
+ case protocol_schema.SummarizeStoryTask:
+ payload = db_payload.SummarizationStoryPayload(story=task.story)
+
+ case protocol_schema.RateSummaryTask:
+ payload = db_payload.RateSummaryPayload(
+ full_text=task.full_text, summary=task.summary, scale=task.scale
+ )
+
+ case protocol_schema.InitialPromptTask:
+ payload = db_payload.InitialPromptPayload(hint=task.hint)
+
+ case protocol_schema.PrompterReplyTask:
+ payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
+
+ case protocol_schema.AssistantReplyTask:
+ payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
+
+ case protocol_schema.RankInitialPromptsTask:
+ payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts)
+
+ case protocol_schema.RankPrompterRepliesTask:
+ payload = db_payload.RankPrompterRepliesPayload(
+ type=task.type, conversation=task.conversation, replies=task.replies
+ )
+
+ case protocol_schema.RankAssistantRepliesTask:
+ payload = db_payload.RankAssistantRepliesPayload(
+ type=task.type, conversation=task.conversation, replies=task.replies
+ )
+
+ case protocol_schema.LabelInitialPromptTask:
+ payload = db_payload.LabelInitialPromptPayload(
+ type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
+ )
+
+ case protocol_schema.LabelPrompterReplyTask:
+ payload = db_payload.LabelPrompterReplyPayload(
+ type=task.type,
+ message_id=task.message_id,
+ conversation=task.conversation,
+ reply=task.reply,
+ valid_labels=task.valid_labels,
+ )
+
+ case protocol_schema.LabelAssistantReplyTask:
+ payload = db_payload.LabelAssistantReplyPayload(
+ type=task.type,
+ message_id=task.message_id,
+ conversation=task.conversation,
+ reply=task.reply,
+ valid_labels=task.valid_labels,
+ )
+
+ case _:
+ raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
+
+ task = self.insert_task(
+ payload=payload,
+ id=task.id,
+ message_tree_id=message_tree_id,
+ parent_message_id=parent_message_id,
+ collective=collective,
+ )
+ assert task.id == task.id
+ return task
+
+ def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
+ validate_frontend_message_id(frontend_message_id)
+
+ # find task
+ task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
+ if task is None:
+ raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
+ if task.expired:
+ raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
+ if task.done or task.ack is not None:
+ raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
+
+ task.frontend_message_id = frontend_message_id
+ task.ack = True
+ # ToDo: check race-condition, transaction
+ self.db.add(task)
+ self.db.commit()
+
+ def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
+ """
+ Mark task as done. No further messages will be accepted for this task.
+ """
+ validate_frontend_message_id(frontend_message_id)
+ task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id)
+
+ if not task:
+ raise OasstError(
+ f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND
+ )
+ if task.expired:
+ raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
+ if not allow_personal_tasks and not task.collective:
+ raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
+ if task.done:
+ raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)
+
+ task.done = True
+ self.db.add(task)
+ self.db.commit()
+
+ def acknowledge_task_failure(self, task_id):
+ # find task
+ task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
+ if task is None:
+ raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
+ if task.expired:
+ raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
+ if task.done or task.ack is not None:
+ raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
+
+ task.ack = False
+ # ToDo: check race-condition, transaction
+ self.db.add(task)
+ self.db.commit()
+
+ def insert_task(
+ self,
+ payload: db_payload.TaskPayload,
+ id: UUID = None,
+ message_tree_id: UUID = None,
+ parent_message_id: UUID = None,
+ collective: bool = False,
+ ) -> Task:
+ c = PayloadContainer(payload=payload)
+ task = Task(
+ id=id,
+ user_id=self.user_id,
+ payload_type=type(payload).__name__,
+ payload=c,
+ api_client_id=self.api_client.id,
+ message_tree_id=message_tree_id,
+ parent_message_id=parent_message_id,
+ collective=collective,
+ )
+ self.db.add(task)
+ self.db.commit()
+ self.db.refresh(task)
+ return task
+
+ def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
+ validate_frontend_message_id(message_id)
+ task = (
+ self.db.query(Task)
+ .filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id)
+ .one_or_none()
+ )
+ return task
diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py
new file mode 100644
index 00000000..b5508899
--- /dev/null
+++ b/backend/oasst_backend/user_repository.py
@@ -0,0 +1,64 @@
+from typing import Optional
+
+from oasst_backend.models import ApiClient, Message, User
+from oasst_shared.schemas import protocol as protocol_schema
+from oasst_shared.schemas.protocol import LeaderboardStats
+from sqlmodel import Session, func
+
+
+class UserRepository:
+ def __init__(self, db: Session, api_client: ApiClient):
+ self.db = db
+ self.api_client = api_client
+
+ def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
+ if not client_user:
+ return None
+ user: User = (
+ self.db.query(User)
+ .filter(
+ User.api_client_id == self.api_client.id,
+ User.username == client_user.id,
+ User.auth_method == client_user.auth_method,
+ )
+ .first()
+ )
+ if user is None:
+ if create_missing:
+ # user is unknown, create new record
+ user = User(
+ username=client_user.id,
+ display_name=client_user.display_name,
+ api_client_id=self.api_client.id,
+ auth_method=client_user.auth_method,
+ )
+ self.db.add(user)
+ self.db.commit()
+ self.db.refresh(user)
+ elif client_user.display_name and client_user.display_name != user.display_name:
+ # we found the user but the display name changed
+ user.display_name = client_user.display_name
+ self.db.add(user)
+ self.db.commit()
+ return user
+
+ def get_user_leaderboard(self, role: str) -> LeaderboardStats:
+ """
+ Get leaderboard stats for Messages created,
+ separate leaderboard for prompts & assistants
+
+ """
+ query = (
+ self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id))
+ .join(User, User.id == Message.user_id, isouter=True)
+ .filter(Message.deleted is not True, Message.role == role)
+ .group_by(Message.user_id, User.username, User.display_name)
+ .order_by(func.count(Message.user_id).desc())
+ )
+
+ result = [
+ {"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]}
+ for i, j in enumerate(query.all(), start=1)
+ ]
+
+ return LeaderboardStats(leaderboard=result)
diff --git a/docs/docs/research/general.md b/docs/docs/research/general.md
index 56f935ac..4186ebac 100644
--- a/docs/docs/research/general.md
+++ b/docs/docs/research/general.md
@@ -1,7 +1,63 @@
-# General
+# Research
This page lists research papers that are relevant to the project.
+## Table of Contents
+
+- Reinforcement Learning from Human Feedback
+- Generating Text From Language Models
+- Automatically Generating Instruction Data for Training
+- Uncertainty Estimation of Language Model Outputs
+
+## Reinforcement Learning from Human Feedback
+
+Reinforcement Learning from Human Feedback (RLHF) is a method for fine-tuning a
+generative language models based on a reward model that is learned from human
+preference data. This method facilitates the learning of instruction-tuned
+models, among other things.
+
+### Learning to summarize from human feedback [[ArXiv](https://arxiv.org/pdf/2009.01325.pdf)], [[Github](https://github.com/openai/summarize-from-feedback)]
+
+> In this work, we show that it is possible to significantly improve summary
+> quality by training a model to optimize for human preferences. We collect a
+> large, high-quality dataset of human comparisons between summaries, train a
+> model to predict the human-preferred summary, and use that model as a reward
+> function to fine-tune a summarization policy using reinforcement learning.
+
+### Training language models to follow instructions with human feedback [[ArXiv](https://arxiv.org/pdf/2203.02155.pdf)]
+
+> Starting with a set of labeler-written prompts and prompts submitted through
+> the OpenAI API, we collect a dataset of labeler demonstrations of the desired
+> model behavior, which we use to fine-tune GPT-3 using supervised learning. We
+> then collect a dataset of rankings of model outputs, which we use to further
+> fine-tune this supervised model using reinforcement learning from human
+> feedback.
+
+### Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback [[ArXiv](https://arxiv.org/pdf/2204.05862.pdf)]
+
+> We apply preference modeling and reinforcement learning from human feedback
+> (RLHF) to finetune language models to act as helpful and harmless assistants.
+> We find this alignment training improves performance on almost all NLP
+> evaluations, and is fully compatible with training for specialized skills such
+> as python coding and summarization.
+
+## Generating Text From Language Models
+
+A language model generates output text token by token, autoregressively. The
+large search space of this task requires some method of narrowing down the set
+of tokens to be considered in each step. This method, in turn, has a big impact
+on the quality of the resulting text.
+
+### RANKGEN: Improving Text Generation with Large Ranking Models [[ArXiv](https://arxiv.org/pdf/2205.09726.pdf)], [[Github](https://github.com/martiansideofthemoon/rankgen)]
+
+> Given an input sequence (or prefix), modern language models often assign high
+> probabilities to output sequences that are repetitive, incoherent, or
+> irrelevant to the prefix; as such, model-generated text also contains such
+> artifacts. To address these issues we present RankGen, a 1.2B parameter
+> encoder model for English that scores model generations given a prefix.
+> RankGen can be flexibly incorporated as a scoring function in beam search and
+> used to decode from any pretrained language model.
+
## Automatically Generating Instruction Data for Training
This line of work is about significantly reducing the need for manually
@@ -32,3 +88,15 @@ models.
> rivals the effectiveness of training on open-source manually-curated datasets,
> surpassing the performance of models such as T0++ and Tk-Instruct across
> various benchmarks.
+
+## Uncertainty Estimation of Language Model Outputs
+
+### Teaching models to express their uncertainty in words [[Arxiv](https://arxiv.org/pdf/2205.14334.pdf)]
+
+> We show that a GPT-3 model can learn to express uncertainty about its own
+> answers in natural language -- without use of model logits. When given a
+> question, the model generates both an answer and a level of confidence (e.g.
+> "90% confidence" or "high confidence"). These levels map to probabilities that
+> are well calibrated. The model also remains moderately calibrated under
+> distribution shift, and is sensitive to uncertainty in its own answers, rather
+> than imitating human examples.
diff --git a/model/reward/instructor/rank_datasets.py b/model/reward/instructor/rank_datasets.py
index 5e7da948..bd53eb02 100644
--- a/model/reward/instructor/rank_datasets.py
+++ b/model/reward/instructor/rank_datasets.py
@@ -19,7 +19,7 @@
"""
from dataclasses import dataclass
-from typing import Optional, Union
+from typing import Dict, List, Optional, Union
import numpy as np
import torch
@@ -35,7 +35,7 @@ class RankGenCollator:
max_length: Optional[int] = None
max_examples: Optional[int] = None
- def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]:
+ def __call__(self, batch: List[Dict[str, str]]) -> Dict[str, torch.Tensor]:
prefixes = []
better_answers = []
worse_answers = []
@@ -193,3 +193,47 @@ class HFSummary(Dataset):
valid_idx = np.random.choice(len(rows), self.max_comparison_per_sample)
# optimize the format later
return context + self.postfix_prompt, [r for idx, r in enumerate(rows) if idx in valid_idx]
+
+
+class HFDataset(Dataset):
+ """
+ This is a base huggingface dataset which written to support the
+ simplest pos-neg pair format
+
+ we should do something like this for supervised datasets
+ """
+
+ def __init__(
+ self, dataset_name, question_field, pos_answer_field, neg_answer_field, subset=None, split=None
+ ) -> None:
+ super().__init__()
+ dataset = load_dataset(dataset_name, subset)
+ if split is not None:
+ dataset = dataset[split]
+
+ self.questions = {}
+ self.index2question = {}
+ for row in dataset:
+ question = row[question_field].strip()
+ pos = row[pos_answer_field]
+ neg = row[neg_answer_field]
+ if question not in self.index2question:
+ self.index2question[len(self.index2question)] = question
+
+ if question not in self.questions:
+ self.questions[question] = []
+ self.questions[question].append((pos.strip(), neg.strip()))
+
+ def __len__(self):
+ return len(self.index2question)
+
+ def __getitem__(self, index):
+ question = self.index2question[index]
+ rows = self.questions[question]
+ # optimize the format later
+ return question, rows
+
+
+class GPTJSynthetic(HFDataset):
+ def __init__(self) -> None:
+ super().__init__("Dahoas/synthetic-instruct-gptj-pairwise", "prompt", "chosen", "rejected", None, "train")
diff --git a/model/reward/instructor/tests/test_dataset.py b/model/reward/instructor/tests/test_dataset.py
index 746a3c1e..832aace3 100644
--- a/model/reward/instructor/tests/test_dataset.py
+++ b/model/reward/instructor/tests/test_dataset.py
@@ -1,5 +1,5 @@
from experimental_dataset import DataCollatorForSummaryScore, HFSummaryQuality
-from rank_datasets import DataCollatorForPairRank, HFSummary, WebGPT
+from rank_datasets import DataCollatorForPairRank, GPTJSynthetic, HFSummary, WebGPT
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
@@ -25,7 +25,7 @@ def test_webgpt():
print(batch["input_ids"].shape)
-def test_hf_quality():
+def test_hf_summary_quality():
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
collate_fn = DataCollatorForSummaryScore(tokenizer, max_length=200)
@@ -35,6 +35,12 @@ def test_hf_quality():
print(batch["input_ids"].shape)
-if __name__ == "__main__":
- test_hf_quality()
- # test_webgpt()
+def test_gptj_dataset():
+ dataset = GPTJSynthetic()
+ tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
+ collate_fn = DataCollatorForPairRank(tokenizer, max_length=1024)
+
+ print(len(dataset))
+ dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=32)
+ for batch in dataloader:
+ batch["input_ids"].shape
diff --git a/model/reward/instructor/trainer.py b/model/reward/instructor/trainer.py
index 68a58a38..940c0708 100644
--- a/model/reward/instructor/trainer.py
+++ b/model/reward/instructor/trainer.py
@@ -1,29 +1,23 @@
import os
from argparse import ArgumentParser
-from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import evaluate
import numpy as np
import torch
from models import RankGenModel
-from rank_datasets import DataCollatorForPairRank, HFSummary, RankGenCollator, WebGPT
+from rank_datasets import DataCollatorForPairRank, RankGenCollator
from torch import nn
-from torch.utils.data import ConcatDataset, Dataset
from transformers import (
AdamW,
AutoModelForSequenceClassification,
- DataCollator,
- EvalPrediction,
PreTrainedModel,
- PreTrainedTokenizerBase,
Trainer,
- TrainerCallback,
TrainingArguments,
get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
-from utils import argument_parsing, freeze_top_n_layers, get_tokenizer, train_val_dataset
+from utils import argument_parsing, freeze_top_n_layers, get_datasets, get_tokenizer
os.environ["WANDB_PROJECT"] = "reward-model"
@@ -32,11 +26,6 @@ parser = ArgumentParser()
parser.add_argument("config", type=str)
-@dataclass
-class CustomTrainingArguments(TrainingArguments):
- loss_function: str = "rank"
-
-
def compute_metrics(eval_pred):
predictions, _ = eval_pred
predictions = np.argmax(predictions, axis=1)
@@ -60,31 +49,12 @@ class RankTrainer(Trainer):
model: Union[PreTrainedModel, nn.Module] = None,
model_name: str = None,
args: Optional[TrainingArguments] = None,
- data_collator: Optional[DataCollator] = None,
- train_dataset: Optional[Dataset] = None,
- eval_dataset: Optional[Dataset] = None,
- tokenizer: Optional[PreTrainedTokenizerBase] = None,
- model_init: Callable[[], PreTrainedModel] = None,
- compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
- callbacks: Optional[List[TrainerCallback]] = None,
- optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
- preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
+ loss_function: str = "rank",
+ **kwargs,
):
- super().__init__(
- model,
- args,
- data_collator,
- train_dataset,
- eval_dataset,
- tokenizer,
- model_init,
- compute_metrics,
- callbacks,
- optimizers,
- preprocess_logits_for_metrics,
- )
- self.loss_fct = RankLoss() if args.loss_function == "rank" else nn.CrossEntropyLoss()
- self.loss_function = args.loss_function
+ super().__init__(model, args, **kwargs)
+ self.loss_fct = RankLoss() if loss_function == "rank" else nn.CrossEntropyLoss()
+ self.loss_function = loss_function
self.model_name = model_name
def compute_loss(self, model, inputs, return_outputs=False):
@@ -163,11 +133,10 @@ if __name__ == "__main__":
params = sum([np.prod(p.size()) for p in model_parameters])
print("Number of trainable : {}M".format(int(params / 1e6)))
- args = CustomTrainingArguments(
+ args = TrainingArguments(
output_dir=f"{model_name}-finetuned",
num_train_epochs=training_conf["num_train_epochs"],
warmup_steps=500,
- loss_function=training_conf["loss"],
learning_rate=training_conf["learning_rate"],
# half_precision_backend="apex",
fp16=training_conf["fp16"],
@@ -184,22 +153,9 @@ if __name__ == "__main__":
save_steps=1000,
report_to="wandb",
)
- train_datasets, evals = [], {}
- if "webgpt" in training_conf["datasets"]:
- web_dataset = WebGPT()
- train, eval = train_val_dataset(web_dataset)
- train_datasets.append(train)
- evals["webgpt"] = eval
- if "hfsummary" in training_conf["datasets"]:
- sum_train = HFSummary(split="train")
- train_datasets.append(sum_train)
- sum_eval = HFSummary(split="valid1")
- assert len(sum_eval) > 0
- evals["hfsummary"] = sum_eval
- train = ConcatDataset(train_datasets)
tokenizer = get_tokenizer(training_conf["tokenizer_name"])
-
+ train, evals = get_datasets(training_conf["datasets"])
if "rankgen" in model_name:
collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"])
else:
@@ -224,8 +180,9 @@ if __name__ == "__main__":
model=model,
model_name=model_name,
args=args,
+ loss_function=training_conf["loss"],
train_dataset=train,
- eval_dataset=eval,
+ eval_dataset=evals,
data_collator=collate_fn,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
diff --git a/model/reward/instructor/utils.py b/model/reward/instructor/utils.py
index fe52c2ef..a6f3da4e 100644
--- a/model/reward/instructor/utils.py
+++ b/model/reward/instructor/utils.py
@@ -1,11 +1,13 @@
import re
+from typing import AnyStr, List
import yaml
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from transformers import AutoTokenizer, T5Tokenizer
-re_reference_remove = re.compile(r"\[([0-9])+\]|\[([0-9])+,([0-9])+\]")
+# @agoryuno contributed this
+re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]")
def webgpt_return_format(row):
@@ -97,6 +99,32 @@ def argument_parsing(parser):
return params
+def get_datasets(dataset_list: List[AnyStr]):
+ from rank_datasets import GPTJSynthetic, HFSummary, WebGPT
+ from torch.utils.data import ConcatDataset
+
+ train_datasets, evals = [], {}
+ for dataset_name in dataset_list:
+ if "webgpt" == dataset_name:
+ web_dataset = WebGPT()
+ train, eval = train_val_dataset(web_dataset, 0.2)
+ train_datasets.append(train)
+ evals["webgpt"] = eval
+ elif "hfsummary" == dataset_name:
+ sum_train = HFSummary(split="train")
+ train_datasets.append(sum_train)
+ sum_eval = HFSummary(split="valid1")
+ assert len(sum_eval) > 0
+ evals["hfsummary"] = sum_eval
+ elif "gptsynthetic" == dataset_name:
+ dataset = GPTJSynthetic()
+ train, eval = train_val_dataset(dataset, 0.1)
+ train_datasets.append(train)
+ evals["gptsynthetic"] = eval
+ train = ConcatDataset(train_datasets)
+ return train, evals
+
+
if __name__ == "__main__":
from transformers import AutoModelForSequenceClassification
diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py
index cbfbeca5..3372cafa 100644
--- a/oasst-shared/oasst_shared/schemas/protocol.py
+++ b/oasst-shared/oasst_shared/schemas/protocol.py
@@ -272,36 +272,29 @@ class TextLabel(str, enum.Enum):
obj.help_text = help_text
return obj
- spam = "spam"
+ spam = "spam", "Seems to be intentionally low-quality or irrelevant"
fails_task = "fails_task", "Fails to follow the correct instruction / task"
not_appropriate = "not_appropriate", "Inappropriate for customer assistant"
violence = "violence", "Encourages or fails to discourage violence/abuse/terrorism/self-harm"
- harmful = (
- "harmful",
- "Harmful content",
- "The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the question about encouraging violence/abuse/terrorism/self-harm.",
+ excessive_harm = (
+ "excessive_harm",
+ "Content likely to cause excessive harm not justifiable in the context",
+ "Harm refers to physical or mental damage or injury to someone or something. Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.",
)
sexual_content = "sexual_content", "Contains sexual content"
- toxicity = "toxicity"
+ toxicity = "toxicity", "Contains rude, abusive, profane or insulting content"
moral_judgement = "moral_judgement", "Expresses moral judgement"
- political_content = "political_content"
- humor = "humor"
- sarcasm = "sarcasm"
- hate_speech = "hate_speech"
- profanity = "profanity"
- ad_hominem = "ad_hominem"
- insult = "insult"
- threat = "threat"
- aggressive = "aggressive"
- misleading = "misleading"
- helpful = "helpful"
- formal = "formal"
- cringe = "cringe"
- creative = "creative"
- beautiful = "beautiful"
- informative = "informative"
- based = "based"
- slang = "slang"
+ political_content = "political_content", "Expresses political views"
+ humor = "humor", "Contains humorous content including sarcasm"
+ hate_speech = (
+ "hate_speech",
+ "Content is abusive or threatening and expresses prejudice against a protected characteristic",
+ "Prejudice refers to preconceived views not based on reason. Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.",
+ )
+ threat = "threat", "Contains a threat against a person or persons"
+ misleading = "misleading", "Contains text which is incorrect or misleading"
+ helpful = "helpful", "Completes the task to a high standard"
+ creative = "creative", "Expresses creativity in responding to the task"
class TextLabels(Interaction):
diff --git a/website/.eslintrc.json b/website/.eslintrc.json
index 04b5d542..690c055c 100644
--- a/website/.eslintrc.json
+++ b/website/.eslintrc.json
@@ -8,7 +8,8 @@
"rules": {
"unused-imports/no-unused-imports": "warn",
"simple-import-sort/imports": "warn",
- "simple-import-sort/exports": "warn"
+ "simple-import-sort/exports": "warn",
+ "eqeqeq": "warn"
},
"plugins": ["simple-import-sort", "unused-imports"]
}
diff --git a/website/.gitignore b/website/.gitignore
index 86e167da..0b9f1dbe 100644
--- a/website/.gitignore
+++ b/website/.gitignore
@@ -39,5 +39,7 @@ next-env.d.ts
*.swp
# cypress
+/cypress/screenshots
+/cypress/videos
/cypress-visual-screenshots/diff
/cypress-visual-screenshots/comparison
diff --git a/website/README.md b/website/README.md
index 11e3ccc4..37f6991b 100644
--- a/website/README.md
+++ b/website/README.md
@@ -153,6 +153,34 @@ When writing code for the website, we have a few best practices:
1. Define functional React components (with types for all properties when
feasible).
+### Developing New Features
+
+When working on new features or making significant changes that can't be done
+within a single Pull Request, we ask that you make use of Feature Flags.
+
+We've set up
+[`react-feature-flags`](https://www.npmjs.com/package/react-feature-flags) to
+make this easier. To get started:
+
+1. Add a new flag entry to `website/src/flags.ts`. We have an example flag you
+ can copy as an example. Be sure to `isActive` to true when testing your
+ features but false when submitting your PR.
+1. Use your flag wherever you add a new UI element. This can be done with:
+
+```js
+import { Flags } from "react-feature-flags";
+...
+
+
+
+```
+
+ You can see an example of how this works by checking `website/src/components/Header/Headers.tsx` where we use `flagTest`.
+
+1. Once you've finished building out the feature and it is ready for everyone
+ to use, it's safe to remove the `Flag` wrappers around your component and
+ the entry in `flags.ts`.
+
### URL Paths
To use stable and consistent URL paths, we recommend the following strategy for
diff --git a/website/package-lock.json b/website/package-lock.json
index 8d94e779..81a9cd5d 100644
--- a/website/package-lock.json
+++ b/website/package-lock.json
@@ -29,6 +29,7 @@
"eslint-config-next": "13.0.6",
"eslint-plugin-simple-import-sort": "^8.0.0",
"focus-visible": "^5.2.0",
+ "formik": "^2.2.9",
"framer-motion": "^6.5.1",
"install": "^0.13.0",
"next": "13.0.6",
@@ -38,6 +39,7 @@
"postcss-focus-visible": "^7.1.0",
"react": "18.2.0",
"react-dom": "18.2.0",
+ "react-feature-flags": "^1.0.0",
"react-icons": "^4.7.1",
"swr": "^2.0.0",
"tailwindcss": "^3.2.4",
@@ -56,7 +58,7 @@
"@storybook/manager-webpack5": "^6.5.15",
"@storybook/react": "^6.5.15",
"@storybook/testing-library": "^0.0.13",
- "@types/node": "18.11.17",
+ "@types/node": "^18.11.17",
"@types/react": "18.0.26",
"@typescript-eslint/eslint-plugin": "^5.47.1",
"babel-loader": "^8.3.0",
@@ -66,7 +68,8 @@
"eslint-plugin-unused-imports": "^2.0.0",
"prettier": "2.8.1",
"prisma": "^4.7.1",
- "typescript": "4.9.4"
+ "ts-node": "^10.9.1",
+ "typescript": "^4.9.4"
}
},
"node_modules/@ampproject/remapping": {
@@ -3122,6 +3125,28 @@
"node": ">=0.1.90"
}
},
+ "node_modules/@cspotcode/source-map-support": {
+ "version": "0.8.1",
+ "resolved": "https://registry.npmjs.org/@cspotcode/source-map-support/-/source-map-support-0.8.1.tgz",
+ "integrity": "sha512-IchNf6dN4tHoMFIn/7OE8LWZ19Y6q/67Bmf6vnGREv8RSbBVb9LPJxEcnwrcwX6ixSvaiGoomAUvu4YSxXrVgw==",
+ "devOptional": true,
+ "dependencies": {
+ "@jridgewell/trace-mapping": "0.3.9"
+ },
+ "engines": {
+ "node": ">=12"
+ }
+ },
+ "node_modules/@cspotcode/source-map-support/node_modules/@jridgewell/trace-mapping": {
+ "version": "0.3.9",
+ "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.9.tgz",
+ "integrity": "sha512-3Belt6tdc8bPgAtbcmdtNJlirVoTmEb5e2gC94PnkwEW9jI6CAHUeoG85tjWP5WquqfavoMtMwiG4P926ZKKuQ==",
+ "devOptional": true,
+ "dependencies": {
+ "@jridgewell/resolve-uri": "^3.0.3",
+ "@jridgewell/sourcemap-codec": "^1.4.10"
+ }
+ },
"node_modules/@cypress/request": {
"version": "2.88.10",
"resolved": "https://registry.npmjs.org/@cypress/request/-/request-2.88.10.tgz",
@@ -10836,6 +10861,30 @@
"@testing-library/dom": ">=7.21.4"
}
},
+ "node_modules/@tsconfig/node10": {
+ "version": "1.0.9",
+ "resolved": "https://registry.npmjs.org/@tsconfig/node10/-/node10-1.0.9.tgz",
+ "integrity": "sha512-jNsYVVxU8v5g43Erja32laIDHXeoNvFEpX33OK4d6hljo3jDhCBDhx5dhCCTMWUojscpAagGiRkBKxpdl9fxqA==",
+ "devOptional": true
+ },
+ "node_modules/@tsconfig/node12": {
+ "version": "1.0.11",
+ "resolved": "https://registry.npmjs.org/@tsconfig/node12/-/node12-1.0.11.tgz",
+ "integrity": "sha512-cqefuRsh12pWyGsIoBKJA9luFu3mRxCA+ORZvA4ktLSzIuCUtWVxGIuXigEwO5/ywWFMZ2QEGKWvkZG1zDMTag==",
+ "devOptional": true
+ },
+ "node_modules/@tsconfig/node14": {
+ "version": "1.0.3",
+ "resolved": "https://registry.npmjs.org/@tsconfig/node14/-/node14-1.0.3.tgz",
+ "integrity": "sha512-ysT8mhdixWK6Hw3i1V2AeRqZ5WfXg1G43mqoYlM2nc6388Fq5jcXyr5mRsqViLx/GJYdoL0bfXD8nmF+Zn/Iow==",
+ "devOptional": true
+ },
+ "node_modules/@tsconfig/node16": {
+ "version": "1.0.3",
+ "resolved": "https://registry.npmjs.org/@tsconfig/node16/-/node16-1.0.3.tgz",
+ "integrity": "sha512-yOlFc+7UtL/89t2ZhjPvvB/DeAr3r+Dq58IgzsFkOAvVC6NMJXmCGjbptdXdR9qsX7pKcTL+s87FtYREi2dEEQ==",
+ "devOptional": true
+ },
"node_modules/@types/aria-query": {
"version": "5.0.1",
"resolved": "https://registry.npmjs.org/@types/aria-query/-/aria-query-5.0.1.tgz",
@@ -10975,7 +11024,7 @@
"version": "18.11.17",
"resolved": "https://registry.npmjs.org/@types/node/-/node-18.11.17.tgz",
"integrity": "sha512-HJSUJmni4BeDHhfzn6nF0sVmd1SMezP7/4F0Lq+aXzmp2xm9O7WXrUtHW/CHlYVtZUbByEvWidHqRtcJXGF2Ng==",
- "dev": true
+ "devOptional": true
},
"node_modules/@types/node-fetch": {
"version": "2.6.2",
@@ -12013,7 +12062,7 @@
"version": "4.1.3",
"resolved": "https://registry.npmjs.org/arg/-/arg-4.1.3.tgz",
"integrity": "sha512-58S9QDqG0Xx27YwPSt9fJxivjYl432YCwfDMfZ+71RAqUrZef7LrKQZ3LHLOwCS4FLNBplP533Zx895SeOCHvA==",
- "dev": true
+ "devOptional": true
},
"node_modules/argparse": {
"version": "1.0.10",
@@ -14691,6 +14740,12 @@
"sha.js": "^2.4.8"
}
},
+ "node_modules/create-require": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/create-require/-/create-require-1.1.1.tgz",
+ "integrity": "sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==",
+ "devOptional": true
+ },
"node_modules/cross-spawn": {
"version": "7.0.3",
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz",
@@ -15462,6 +15517,15 @@
"resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz",
"integrity": "sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw=="
},
+ "node_modules/diff": {
+ "version": "4.0.2",
+ "resolved": "https://registry.npmjs.org/diff/-/diff-4.0.2.tgz",
+ "integrity": "sha512-58lmxKSA4BNyLz+HHMUzlOEpg09FV+ev6ZMe3vJihgdxzgcwZ8VoEEPmALCZG9LmqfVoNMMKpttIYTVG6uDY7A==",
+ "devOptional": true,
+ "engines": {
+ "node": ">=0.3.1"
+ }
+ },
"node_modules/diffie-hellman": {
"version": "5.0.3",
"resolved": "https://registry.npmjs.org/diffie-hellman/-/diffie-hellman-5.0.3.tgz",
@@ -17688,6 +17752,47 @@
"node": ">= 6"
}
},
+ "node_modules/formik": {
+ "version": "2.2.9",
+ "resolved": "https://registry.npmjs.org/formik/-/formik-2.2.9.tgz",
+ "integrity": "sha512-LQLcISMmf1r5at4/gyJigGn0gOwFbeEAlji+N9InZF6LIMXnFNkO42sCI8Jt84YZggpD4cPWObAZaxpEFtSzNA==",
+ "funding": [
+ {
+ "type": "individual",
+ "url": "https://opencollective.com/formik"
+ }
+ ],
+ "dependencies": {
+ "deepmerge": "^2.1.1",
+ "hoist-non-react-statics": "^3.3.0",
+ "lodash": "^4.17.21",
+ "lodash-es": "^4.17.21",
+ "react-fast-compare": "^2.0.1",
+ "tiny-warning": "^1.0.2",
+ "tslib": "^1.10.0"
+ },
+ "peerDependencies": {
+ "react": ">=16.8.0"
+ }
+ },
+ "node_modules/formik/node_modules/deepmerge": {
+ "version": "2.2.1",
+ "resolved": "https://registry.npmjs.org/deepmerge/-/deepmerge-2.2.1.tgz",
+ "integrity": "sha512-R9hc1Xa/NOBi9WRVUWg19rl1UB7Tt4kuPd+thNJgFZoxXsTz7ncaPaeIm+40oSGuP33DfMb4sZt1QIGiJzC4EA==",
+ "engines": {
+ "node": ">=0.10.0"
+ }
+ },
+ "node_modules/formik/node_modules/react-fast-compare": {
+ "version": "2.0.4",
+ "resolved": "https://registry.npmjs.org/react-fast-compare/-/react-fast-compare-2.0.4.tgz",
+ "integrity": "sha512-suNP+J1VU1MWFKcyt7RtjiSWUjvidmQSlqu+eHslq+342xCbGTYmC0mEhPCOHxlW0CywylOC1u2DFAT+bv4dBw=="
+ },
+ "node_modules/formik/node_modules/tslib": {
+ "version": "1.14.1",
+ "resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz",
+ "integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg=="
+ },
"node_modules/forwarded": {
"version": "0.2.0",
"resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz",
@@ -20409,8 +20514,12 @@
"node_modules/lodash": {
"version": "4.17.21",
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz",
- "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==",
- "dev": true
+ "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg=="
+ },
+ "node_modules/lodash-es": {
+ "version": "4.17.21",
+ "resolved": "https://registry.npmjs.org/lodash-es/-/lodash-es-4.17.21.tgz",
+ "integrity": "sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw=="
},
"node_modules/lodash.debounce": {
"version": "4.0.8",
@@ -20690,6 +20799,12 @@
"semver": "bin/semver"
}
},
+ "node_modules/make-error": {
+ "version": "1.3.6",
+ "resolved": "https://registry.npmjs.org/make-error/-/make-error-1.3.6.tgz",
+ "integrity": "sha512-s8UhlNe7vPKomQhC1qFelMokr/Sc3AgNbso3n74mVPA5LTZwkB9NlXf4XPamLxJE8h0gh73rM94xvwRT2CVInw==",
+ "devOptional": true
+ },
"node_modules/makeerror": {
"version": "1.0.12",
"resolved": "https://registry.npmjs.org/makeerror/-/makeerror-1.0.12.tgz",
@@ -26310,6 +26425,16 @@
"resolved": "https://registry.npmjs.org/react-fast-compare/-/react-fast-compare-3.2.0.tgz",
"integrity": "sha512-rtGImPZ0YyLrscKI9xTpV8psd6I8VAtjKCzQDlzyDvqJA8XOW78TXYQwNRNd8g8JZnDu8q9Fu/1v4HPAVwVdHA=="
},
+ "node_modules/react-feature-flags": {
+ "version": "1.0.0",
+ "resolved": "https://registry.npmjs.org/react-feature-flags/-/react-feature-flags-1.0.0.tgz",
+ "integrity": "sha512-KBFUkXjF7ifGWEQr2Ida4LdAtKGDOwFdTRlXipWxGP9a43vUBxP6IscpYQofGjlzlBcgmFKuzubcVheB6NliEg==",
+ "peerDependencies": {
+ "prop-types": "^15.5.4",
+ "react": ">= 16.3.0",
+ "react-dom": ">= 16.3.0"
+ }
+ },
"node_modules/react-focus-lock": {
"version": "2.9.2",
"resolved": "https://registry.npmjs.org/react-focus-lock/-/react-focus-lock-2.9.2.tgz",
@@ -29081,6 +29206,11 @@
"resolved": "https://registry.npmjs.org/tiny-invariant/-/tiny-invariant-1.3.1.tgz",
"integrity": "sha512-AD5ih2NlSssTCwsMznbvwMZpJ1cbhkGd2uueNxzv2jDlEeZdU04JQfRnggJQ8DrcVBGjAsCKwFBbDlVNtEMlzw=="
},
+ "node_modules/tiny-warning": {
+ "version": "1.0.3",
+ "resolved": "https://registry.npmjs.org/tiny-warning/-/tiny-warning-1.0.3.tgz",
+ "integrity": "sha512-lBN9zLN/oAf68o3zNXYrdCt1kP8WsiGW8Oo2ka41b2IM5JL/S1CTyX1rW0mb/zSuJun0ZUrDxx4sqvYS2FWzPA=="
+ },
"node_modules/tmp": {
"version": "0.2.1",
"resolved": "https://registry.npmjs.org/tmp/-/tmp-0.2.1.tgz",
@@ -29247,6 +29377,70 @@
"node": ">=6.10"
}
},
+ "node_modules/ts-node": {
+ "version": "10.9.1",
+ "resolved": "https://registry.npmjs.org/ts-node/-/ts-node-10.9.1.tgz",
+ "integrity": "sha512-NtVysVPkxxrwFGUUxGYhfux8k78pQB3JqYBXlLRZgdGUqTO5wU/UyHop5p70iEbGhB7q5KmiZiU0Y3KlJrScEw==",
+ "devOptional": true,
+ "dependencies": {
+ "@cspotcode/source-map-support": "^0.8.0",
+ "@tsconfig/node10": "^1.0.7",
+ "@tsconfig/node12": "^1.0.7",
+ "@tsconfig/node14": "^1.0.0",
+ "@tsconfig/node16": "^1.0.2",
+ "acorn": "^8.4.1",
+ "acorn-walk": "^8.1.1",
+ "arg": "^4.1.0",
+ "create-require": "^1.1.0",
+ "diff": "^4.0.1",
+ "make-error": "^1.1.1",
+ "v8-compile-cache-lib": "^3.0.1",
+ "yn": "3.1.1"
+ },
+ "bin": {
+ "ts-node": "dist/bin.js",
+ "ts-node-cwd": "dist/bin-cwd.js",
+ "ts-node-esm": "dist/bin-esm.js",
+ "ts-node-script": "dist/bin-script.js",
+ "ts-node-transpile-only": "dist/bin-transpile.js",
+ "ts-script": "dist/bin-script-deprecated.js"
+ },
+ "peerDependencies": {
+ "@swc/core": ">=1.2.50",
+ "@swc/wasm": ">=1.2.50",
+ "@types/node": "*",
+ "typescript": ">=2.7"
+ },
+ "peerDependenciesMeta": {
+ "@swc/core": {
+ "optional": true
+ },
+ "@swc/wasm": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/ts-node/node_modules/acorn": {
+ "version": "8.8.1",
+ "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.8.1.tgz",
+ "integrity": "sha512-7zFpHzhnqYKrkYdUjF1HI1bzd0VygEGX8lFk4k5zVMqHEoES+P+7TKI+EvLO9WVMJ8eekdO0aDEK044xTXwPPA==",
+ "devOptional": true,
+ "bin": {
+ "acorn": "bin/acorn"
+ },
+ "engines": {
+ "node": ">=0.4.0"
+ }
+ },
+ "node_modules/ts-node/node_modules/acorn-walk": {
+ "version": "8.2.0",
+ "resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.2.0.tgz",
+ "integrity": "sha512-k+iyHEuPgSw6SbuDpGQM+06HQUa04DZ3o+F6CSzXMvvI5KMvnaEqXe+YVe555R9nn6GPt404fos4wcgpw12SDA==",
+ "devOptional": true,
+ "engines": {
+ "node": ">=0.4.0"
+ }
+ },
"node_modules/ts-pnp": {
"version": "1.2.0",
"resolved": "https://registry.npmjs.org/ts-pnp/-/ts-pnp-1.2.0.tgz",
@@ -29951,6 +30145,12 @@
"integrity": "sha512-dsNgbLaTrd6l3MMxTtouOCFw4CBFc/3a+GgYA2YyrJvyQ1u6q4pcu3ktLoUZ/VN/Aw9WsauazbgsgdfVWgAKQg==",
"dev": true
},
+ "node_modules/v8-compile-cache-lib": {
+ "version": "3.0.1",
+ "resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",
+ "integrity": "sha512-wa7YjyUGfNZngI/vtK0UHAN+lgDCxBPCylVXGp0zu59Fz5aiGtNXaq3DhIov063MorB+VfufLh3JlF2KdTK3xg==",
+ "devOptional": true
+ },
"node_modules/v8-to-istanbul": {
"version": "9.0.1",
"resolved": "https://registry.npmjs.org/v8-to-istanbul/-/v8-to-istanbul-9.0.1.tgz",
@@ -30850,6 +31050,15 @@
"fd-slicer": "~1.1.0"
}
},
+ "node_modules/yn": {
+ "version": "3.1.1",
+ "resolved": "https://registry.npmjs.org/yn/-/yn-3.1.1.tgz",
+ "integrity": "sha512-Ux4ygGWsu2c7isFWe8Yu1YluJmqVhxqK2cLXNQA5AcC3QfbGNpM7fu0Y8b/z16pXLnFxZYvWhd3fhBY9DLmC6Q==",
+ "devOptional": true,
+ "engines": {
+ "node": ">=6"
+ }
+ },
"node_modules/yocto-queue": {
"version": "0.1.0",
"resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz",
@@ -33044,6 +33253,27 @@
"dev": true,
"optional": true
},
+ "@cspotcode/source-map-support": {
+ "version": "0.8.1",
+ "resolved": "https://registry.npmjs.org/@cspotcode/source-map-support/-/source-map-support-0.8.1.tgz",
+ "integrity": "sha512-IchNf6dN4tHoMFIn/7OE8LWZ19Y6q/67Bmf6vnGREv8RSbBVb9LPJxEcnwrcwX6ixSvaiGoomAUvu4YSxXrVgw==",
+ "devOptional": true,
+ "requires": {
+ "@jridgewell/trace-mapping": "0.3.9"
+ },
+ "dependencies": {
+ "@jridgewell/trace-mapping": {
+ "version": "0.3.9",
+ "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.9.tgz",
+ "integrity": "sha512-3Belt6tdc8bPgAtbcmdtNJlirVoTmEb5e2gC94PnkwEW9jI6CAHUeoG85tjWP5WquqfavoMtMwiG4P926ZKKuQ==",
+ "devOptional": true,
+ "requires": {
+ "@jridgewell/resolve-uri": "^3.0.3",
+ "@jridgewell/sourcemap-codec": "^1.4.10"
+ }
+ }
+ }
+ },
"@cypress/request": {
"version": "2.88.10",
"resolved": "https://registry.npmjs.org/@cypress/request/-/request-2.88.10.tgz",
@@ -38884,6 +39114,30 @@
"@babel/runtime": "^7.12.5"
}
},
+ "@tsconfig/node10": {
+ "version": "1.0.9",
+ "resolved": "https://registry.npmjs.org/@tsconfig/node10/-/node10-1.0.9.tgz",
+ "integrity": "sha512-jNsYVVxU8v5g43Erja32laIDHXeoNvFEpX33OK4d6hljo3jDhCBDhx5dhCCTMWUojscpAagGiRkBKxpdl9fxqA==",
+ "devOptional": true
+ },
+ "@tsconfig/node12": {
+ "version": "1.0.11",
+ "resolved": "https://registry.npmjs.org/@tsconfig/node12/-/node12-1.0.11.tgz",
+ "integrity": "sha512-cqefuRsh12pWyGsIoBKJA9luFu3mRxCA+ORZvA4ktLSzIuCUtWVxGIuXigEwO5/ywWFMZ2QEGKWvkZG1zDMTag==",
+ "devOptional": true
+ },
+ "@tsconfig/node14": {
+ "version": "1.0.3",
+ "resolved": "https://registry.npmjs.org/@tsconfig/node14/-/node14-1.0.3.tgz",
+ "integrity": "sha512-ysT8mhdixWK6Hw3i1V2AeRqZ5WfXg1G43mqoYlM2nc6388Fq5jcXyr5mRsqViLx/GJYdoL0bfXD8nmF+Zn/Iow==",
+ "devOptional": true
+ },
+ "@tsconfig/node16": {
+ "version": "1.0.3",
+ "resolved": "https://registry.npmjs.org/@tsconfig/node16/-/node16-1.0.3.tgz",
+ "integrity": "sha512-yOlFc+7UtL/89t2ZhjPvvB/DeAr3r+Dq58IgzsFkOAvVC6NMJXmCGjbptdXdR9qsX7pKcTL+s87FtYREi2dEEQ==",
+ "devOptional": true
+ },
"@types/aria-query": {
"version": "5.0.1",
"resolved": "https://registry.npmjs.org/@types/aria-query/-/aria-query-5.0.1.tgz",
@@ -39023,7 +39277,7 @@
"version": "18.11.17",
"resolved": "https://registry.npmjs.org/@types/node/-/node-18.11.17.tgz",
"integrity": "sha512-HJSUJmni4BeDHhfzn6nF0sVmd1SMezP7/4F0Lq+aXzmp2xm9O7WXrUtHW/CHlYVtZUbByEvWidHqRtcJXGF2Ng==",
- "dev": true
+ "devOptional": true
},
"@types/node-fetch": {
"version": "2.6.2",
@@ -39875,7 +40129,7 @@
"version": "4.1.3",
"resolved": "https://registry.npmjs.org/arg/-/arg-4.1.3.tgz",
"integrity": "sha512-58S9QDqG0Xx27YwPSt9fJxivjYl432YCwfDMfZ+71RAqUrZef7LrKQZ3LHLOwCS4FLNBplP533Zx895SeOCHvA==",
- "dev": true
+ "devOptional": true
},
"argparse": {
"version": "1.0.10",
@@ -41964,6 +42218,12 @@
"sha.js": "^2.4.8"
}
},
+ "create-require": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/create-require/-/create-require-1.1.1.tgz",
+ "integrity": "sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==",
+ "devOptional": true
+ },
"cross-spawn": {
"version": "7.0.3",
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz",
@@ -42547,6 +42807,12 @@
"resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz",
"integrity": "sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw=="
},
+ "diff": {
+ "version": "4.0.2",
+ "resolved": "https://registry.npmjs.org/diff/-/diff-4.0.2.tgz",
+ "integrity": "sha512-58lmxKSA4BNyLz+HHMUzlOEpg09FV+ev6ZMe3vJihgdxzgcwZ8VoEEPmALCZG9LmqfVoNMMKpttIYTVG6uDY7A==",
+ "devOptional": true
+ },
"diffie-hellman": {
"version": "5.0.3",
"resolved": "https://registry.npmjs.org/diffie-hellman/-/diffie-hellman-5.0.3.tgz",
@@ -44293,6 +44559,37 @@
"mime-types": "^2.1.12"
}
},
+ "formik": {
+ "version": "2.2.9",
+ "resolved": "https://registry.npmjs.org/formik/-/formik-2.2.9.tgz",
+ "integrity": "sha512-LQLcISMmf1r5at4/gyJigGn0gOwFbeEAlji+N9InZF6LIMXnFNkO42sCI8Jt84YZggpD4cPWObAZaxpEFtSzNA==",
+ "requires": {
+ "deepmerge": "^2.1.1",
+ "hoist-non-react-statics": "^3.3.0",
+ "lodash": "^4.17.21",
+ "lodash-es": "^4.17.21",
+ "react-fast-compare": "^2.0.1",
+ "tiny-warning": "^1.0.2",
+ "tslib": "^1.10.0"
+ },
+ "dependencies": {
+ "deepmerge": {
+ "version": "2.2.1",
+ "resolved": "https://registry.npmjs.org/deepmerge/-/deepmerge-2.2.1.tgz",
+ "integrity": "sha512-R9hc1Xa/NOBi9WRVUWg19rl1UB7Tt4kuPd+thNJgFZoxXsTz7ncaPaeIm+40oSGuP33DfMb4sZt1QIGiJzC4EA=="
+ },
+ "react-fast-compare": {
+ "version": "2.0.4",
+ "resolved": "https://registry.npmjs.org/react-fast-compare/-/react-fast-compare-2.0.4.tgz",
+ "integrity": "sha512-suNP+J1VU1MWFKcyt7RtjiSWUjvidmQSlqu+eHslq+342xCbGTYmC0mEhPCOHxlW0CywylOC1u2DFAT+bv4dBw=="
+ },
+ "tslib": {
+ "version": "1.14.1",
+ "resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz",
+ "integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg=="
+ }
+ }
+ },
"forwarded": {
"version": "0.2.0",
"resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz",
@@ -46307,8 +46604,12 @@
"lodash": {
"version": "4.17.21",
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz",
- "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==",
- "dev": true
+ "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg=="
+ },
+ "lodash-es": {
+ "version": "4.17.21",
+ "resolved": "https://registry.npmjs.org/lodash-es/-/lodash-es-4.17.21.tgz",
+ "integrity": "sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw=="
},
"lodash.debounce": {
"version": "4.0.8",
@@ -46525,6 +46826,12 @@
}
}
},
+ "make-error": {
+ "version": "1.3.6",
+ "resolved": "https://registry.npmjs.org/make-error/-/make-error-1.3.6.tgz",
+ "integrity": "sha512-s8UhlNe7vPKomQhC1qFelMokr/Sc3AgNbso3n74mVPA5LTZwkB9NlXf4XPamLxJE8h0gh73rM94xvwRT2CVInw==",
+ "devOptional": true
+ },
"makeerror": {
"version": "1.0.12",
"resolved": "https://registry.npmjs.org/makeerror/-/makeerror-1.0.12.tgz",
@@ -50542,6 +50849,12 @@
"resolved": "https://registry.npmjs.org/react-fast-compare/-/react-fast-compare-3.2.0.tgz",
"integrity": "sha512-rtGImPZ0YyLrscKI9xTpV8psd6I8VAtjKCzQDlzyDvqJA8XOW78TXYQwNRNd8g8JZnDu8q9Fu/1v4HPAVwVdHA=="
},
+ "react-feature-flags": {
+ "version": "1.0.0",
+ "resolved": "https://registry.npmjs.org/react-feature-flags/-/react-feature-flags-1.0.0.tgz",
+ "integrity": "sha512-KBFUkXjF7ifGWEQr2Ida4LdAtKGDOwFdTRlXipWxGP9a43vUBxP6IscpYQofGjlzlBcgmFKuzubcVheB6NliEg==",
+ "requires": {}
+ },
"react-focus-lock": {
"version": "2.9.2",
"resolved": "https://registry.npmjs.org/react-focus-lock/-/react-focus-lock-2.9.2.tgz",
@@ -52722,6 +53035,11 @@
"resolved": "https://registry.npmjs.org/tiny-invariant/-/tiny-invariant-1.3.1.tgz",
"integrity": "sha512-AD5ih2NlSssTCwsMznbvwMZpJ1cbhkGd2uueNxzv2jDlEeZdU04JQfRnggJQ8DrcVBGjAsCKwFBbDlVNtEMlzw=="
},
+ "tiny-warning": {
+ "version": "1.0.3",
+ "resolved": "https://registry.npmjs.org/tiny-warning/-/tiny-warning-1.0.3.tgz",
+ "integrity": "sha512-lBN9zLN/oAf68o3zNXYrdCt1kP8WsiGW8Oo2ka41b2IM5JL/S1CTyX1rW0mb/zSuJun0ZUrDxx4sqvYS2FWzPA=="
+ },
"tmp": {
"version": "0.2.1",
"resolved": "https://registry.npmjs.org/tmp/-/tmp-0.2.1.tgz",
@@ -52852,6 +53170,41 @@
"integrity": "sha512-q5W7tVM71e2xjHZTlgfTDoPF/SmqKG5hddq9SzR49CH2hayqRKJtQ4mtRlSxKaJlR/+9rEM+mnBHf7I2/BQcpQ==",
"dev": true
},
+ "ts-node": {
+ "version": "10.9.1",
+ "resolved": "https://registry.npmjs.org/ts-node/-/ts-node-10.9.1.tgz",
+ "integrity": "sha512-NtVysVPkxxrwFGUUxGYhfux8k78pQB3JqYBXlLRZgdGUqTO5wU/UyHop5p70iEbGhB7q5KmiZiU0Y3KlJrScEw==",
+ "devOptional": true,
+ "requires": {
+ "@cspotcode/source-map-support": "^0.8.0",
+ "@tsconfig/node10": "^1.0.7",
+ "@tsconfig/node12": "^1.0.7",
+ "@tsconfig/node14": "^1.0.0",
+ "@tsconfig/node16": "^1.0.2",
+ "acorn": "^8.4.1",
+ "acorn-walk": "^8.1.1",
+ "arg": "^4.1.0",
+ "create-require": "^1.1.0",
+ "diff": "^4.0.1",
+ "make-error": "^1.1.1",
+ "v8-compile-cache-lib": "^3.0.1",
+ "yn": "3.1.1"
+ },
+ "dependencies": {
+ "acorn": {
+ "version": "8.8.1",
+ "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.8.1.tgz",
+ "integrity": "sha512-7zFpHzhnqYKrkYdUjF1HI1bzd0VygEGX8lFk4k5zVMqHEoES+P+7TKI+EvLO9WVMJ8eekdO0aDEK044xTXwPPA==",
+ "devOptional": true
+ },
+ "acorn-walk": {
+ "version": "8.2.0",
+ "resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.2.0.tgz",
+ "integrity": "sha512-k+iyHEuPgSw6SbuDpGQM+06HQUa04DZ3o+F6CSzXMvvI5KMvnaEqXe+YVe555R9nn6GPt404fos4wcgpw12SDA==",
+ "devOptional": true
+ }
+ }
+ },
"ts-pnp": {
"version": "1.2.0",
"resolved": "https://registry.npmjs.org/ts-pnp/-/ts-pnp-1.2.0.tgz",
@@ -53362,6 +53715,12 @@
"integrity": "sha512-dsNgbLaTrd6l3MMxTtouOCFw4CBFc/3a+GgYA2YyrJvyQ1u6q4pcu3ktLoUZ/VN/Aw9WsauazbgsgdfVWgAKQg==",
"dev": true
},
+ "v8-compile-cache-lib": {
+ "version": "3.0.1",
+ "resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",
+ "integrity": "sha512-wa7YjyUGfNZngI/vtK0UHAN+lgDCxBPCylVXGp0zu59Fz5aiGtNXaq3DhIov063MorB+VfufLh3JlF2KdTK3xg==",
+ "devOptional": true
+ },
"v8-to-istanbul": {
"version": "9.0.1",
"resolved": "https://registry.npmjs.org/v8-to-istanbul/-/v8-to-istanbul-9.0.1.tgz",
@@ -54083,6 +54442,12 @@
"fd-slicer": "~1.1.0"
}
},
+ "yn": {
+ "version": "3.1.1",
+ "resolved": "https://registry.npmjs.org/yn/-/yn-3.1.1.tgz",
+ "integrity": "sha512-Ux4ygGWsu2c7isFWe8Yu1YluJmqVhxqK2cLXNQA5AcC3QfbGNpM7fu0Y8b/z16pXLnFxZYvWhd3fhBY9DLmC6Q==",
+ "devOptional": true
+ },
"yocto-queue": {
"version": "0.1.0",
"resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz",
diff --git a/website/package.json b/website/package.json
index eb85568b..032125dd 100644
--- a/website/package.json
+++ b/website/package.json
@@ -18,6 +18,9 @@
"fix:format": "prettier --write ./src",
"fix": "npm run fix:format && npm run fix:lint"
},
+ "prisma": {
+ "seed": "ts-node --compiler-options {\"module\":\"CommonJS\"} prisma/seed.ts"
+ },
"dependencies": {
"@chakra-ui/react": "^2.4.4",
"@dnd-kit/core": "^6.0.6",
@@ -40,6 +43,7 @@
"eslint-config-next": "13.0.6",
"eslint-plugin-simple-import-sort": "^8.0.0",
"focus-visible": "^5.2.0",
+ "formik": "^2.2.9",
"framer-motion": "^6.5.1",
"install": "^0.13.0",
"next": "13.0.6",
@@ -49,6 +53,7 @@
"postcss-focus-visible": "^7.1.0",
"react": "18.2.0",
"react-dom": "18.2.0",
+ "react-feature-flags": "^1.0.0",
"react-icons": "^4.7.1",
"swr": "^2.0.0",
"tailwindcss": "^3.2.4",
@@ -67,7 +72,7 @@
"@storybook/manager-webpack5": "^6.5.15",
"@storybook/react": "^6.5.15",
"@storybook/testing-library": "^0.0.13",
- "@types/node": "18.11.17",
+ "@types/node": "^18.11.17",
"@types/react": "18.0.26",
"@typescript-eslint/eslint-plugin": "^5.47.1",
"babel-loader": "^8.3.0",
@@ -77,6 +82,7 @@
"eslint-plugin-unused-imports": "^2.0.0",
"prettier": "2.8.1",
"prisma": "^4.7.1",
- "typescript": "4.9.4"
+ "ts-node": "^10.9.1",
+ "typescript": "^4.9.4"
}
}
diff --git a/website/prisma/seed.ts b/website/prisma/seed.ts
new file mode 100644
index 00000000..0933cad4
--- /dev/null
+++ b/website/prisma/seed.ts
@@ -0,0 +1,57 @@
+/**
+ * A seed function to inject test data into the web database.
+ *
+ * Use by running
+ * npx prisma db seed
+ */
+
+import { PrismaClient } from "@prisma/client";
+const prisma = new PrismaClient();
+
+async function main() {
+ const users = [
+ { email: "general.user.a@example.com", name: "A", role: "general" },
+ { email: "general.user.b@example.com", name: "B", role: "general" },
+ { email: "general.user.c@example.com", name: "C", role: "general" },
+ { email: "general.user.d@example.com", name: "D", role: "general" },
+ { email: "general.user.e@example.com", name: "E", role: "general" },
+ { email: "general.user.f@example.com", name: "F", role: "general" },
+ { email: "general.user.g@example.com", name: "G", role: "general" },
+ { email: "general.user.h@example.com", name: "H", role: "general" },
+ { email: "general.user.i@example.com", name: "I", role: "general" },
+ { email: "general.user.j@example.com", name: "J", role: "general" },
+ { email: "general.user.k@example.com", name: "K", role: "general" },
+ { email: "general.user.l@example.com", name: "L", role: "general" },
+ { email: "general.user.m@example.com", name: "M", role: "general" },
+ { email: "general.user.n@example.com", name: "N", role: "general" },
+ { email: "general.user.o@example.com", name: "O", role: "general" },
+ { email: "general.user.p@example.com", name: "P", role: "general" },
+ { email: "general.user.q@example.com", name: "Q", role: "general" },
+ { email: "general.user.r@example.com", name: "R", role: "general" },
+ { email: "malicious.user.1@example.com", name: "M1", role: "general" },
+ { email: "malicious.user.2@example.com", name: "M2", role: "general" },
+ ];
+ await Promise.all(
+ users.map(async ({ email, name, role }) => {
+ await prisma.user.upsert({
+ where: { email },
+ update: { name, role },
+ create: {
+ email,
+ name,
+ role,
+ },
+ });
+ })
+ );
+}
+
+main()
+ .then(async () => {
+ await prisma.$disconnect();
+ })
+ .catch(async (e) => {
+ console.error(e);
+ await prisma.$disconnect();
+ process.exit(1);
+ });
diff --git a/website/src/components/Buttons/Skip.tsx b/website/src/components/Buttons/Skip.tsx
index 74ca0926..8440e348 100644
--- a/website/src/components/Buttons/Skip.tsx
+++ b/website/src/components/Buttons/Skip.tsx
@@ -1,9 +1,63 @@
-import { Button, ButtonProps } from "@chakra-ui/react";
+import {
+ Button,
+ ButtonProps,
+ Menu,
+ MenuButton,
+ MenuItem,
+ MenuList,
+ Modal,
+ ModalBody,
+ ModalCloseButton,
+ ModalContent,
+ ModalFooter,
+ ModalHeader,
+ ModalOverlay,
+ Textarea,
+ useDisclosure,
+} from "@chakra-ui/react";
+import { useState } from "react";
+import { FaChevronDown } from "react-icons/fa";
+
+interface SkipButtonProps extends ButtonProps {
+ onSkip: (reason: string) => void;
+}
+
+export const SkipButton = ({ onSkip, ...props }: SkipButtonProps) => {
+ const { isOpen, onOpen: showModal, onClose: closeModal } = useDisclosure();
+ const [value, setValue] = useState("");
+
+ const onSubmit = () => {
+ onSkip(value);
+ setValue("");
+ closeModal();
+ };
-export const SkipButton = ({ children, ...props }: ButtonProps) => {
return (
-
+ <>
+
+
+
+
+ Skip
+
+
+
+
+
+
+
+
+
+ >
);
};
diff --git a/website/src/components/Dashboard/TaskOption.tsx b/website/src/components/Dashboard/TaskOption.tsx
index 5e6ceb2f..1c070e17 100644
--- a/website/src/components/Dashboard/TaskOption.tsx
+++ b/website/src/components/Dashboard/TaskOption.tsx
@@ -3,7 +3,7 @@ import Link from "next/link";
import { TaskCategory, TaskTypes } from "../Tasks/TaskTypes";
-const displayTaskCategories = [TaskCategory.Create, TaskCategory.Evaluate];
+const displayTaskCategories = [TaskCategory.Create, TaskCategory.Evaluate, TaskCategory.Label];
export const TaskOption = () => {
const backgroundColor = useColorModeValue("white", "gray.700");
@@ -12,9 +12,9 @@ export const TaskOption = () => {
{displayTaskCategories.map((category, categoryIndex) => (
- {TaskCategory[category]}
+ {category}
- {TaskTypes.filter((task) => task.category == category).map((item, itemIndex) => (
+ {TaskTypes.filter((task) => task.category === category).map((item, itemIndex) => (
{
const [isEditing, setIsEditing] = useBoolean();
- const { trigger } = useSWRMutation("/api/v1/text_labels", poster, {
+ const { trigger } = useSWRMutation("/api/set_label", poster, {
onSuccess: () => {
setIsEditing.off;
},
@@ -42,7 +42,12 @@ export const FlaggableElement = (props) => {
label_map.set(flag.attributeName, sliderValues[i]);
}
});
- trigger({ post_id: props.post_id, label_map: Object.fromEntries(label_map), text: props.text });
+ trigger({
+ message_id: props.message_id,
+ post_id: props.post_id,
+ label_map: Object.fromEntries(label_map),
+ text: props.text,
+ });
};
const [checkboxValues, setCheckboxValues] = useState(new Array(TEXT_LABEL_FLAGS.length).fill(false));
const [sliderValues, setSliderValues] = useState(new Array(TEXT_LABEL_FLAGS.length).fill(1));
@@ -118,7 +123,8 @@ export const FlaggableElement = (props) => {
);
};
-function FlagCheckbox(props: {
+
+export function FlagCheckbox(props: {
option: textFlagLabels;
idx: number;
checkboxValues: boolean[];
@@ -183,40 +189,40 @@ interface textFlagLabels {
const TEXT_LABEL_FLAGS: textFlagLabels[] = [
// For the time being this list is configured on the FE.
// In the future it may be provided by the API.
+ // {
+ // attributeName: "fails_task",
+ // labelText: "Fails to follow the correct instruction / task",
+ // additionalExplanation: "__TODO__",
+ // },
+ // {
+ // attributeName: "not_customer_assistant_appropriate",
+ // labelText: "Inappropriate for customer assistant",
+ // additionalExplanation: "__TODO__",
+ // },
{
- attributeName: "fails_task",
- labelText: "Fails to follow the correct instruction / task",
- additionalExplanation: "__TODO__",
- },
- {
- attributeName: "not_customer_assistant_appropriate",
- labelText: "Inappropriate for customer assistant",
- additionalExplanation: "__TODO__",
- },
- {
- attributeName: "contains_sexual_content",
+ attributeName: "sexual_content",
labelText: "Contains sexual content",
},
{
- attributeName: "contains_violent_content",
+ attributeName: "violence",
labelText: "Contains violent content",
},
- {
- attributeName: "encourages_violence",
- labelText: "Encourages or fails to discourage violence/abuse/terrorism/self-harm",
- },
- {
- attributeName: "denigrates_a_protected_class",
- labelText: "Denigrates a protected class",
- },
- {
- attributeName: "gives_harmful_advice",
- labelText: "Fails to follow the correct instruction / task",
- additionalExplanation:
- "The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the question about encouraging violence/abuse/terrorism/self-harm.",
- },
- {
- attributeName: "expresses_moral_judgement",
- labelText: "Expresses moral judgement",
- },
+ // {
+ // attributeName: "encourages_violence",
+ // labelText: "Encourages or fails to discourage violence/abuse/terrorism/self-harm",
+ // },
+ // {
+ // attributeName: "denigrates_a_protected_class",
+ // labelText: "Denigrates a protected class",
+ // },
+ // {
+ // attributeName: "gives_harmful_advice",
+ // labelText: "Fails to follow the correct instruction / task",
+ // additionalExplanation:
+ // "The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the question about encouraging violence/abuse/terrorism/self-harm.",
+ // },
+ // {
+ // attributeName: "expresses_moral_judgement",
+ // labelText: "Expresses moral judgement",
+ // },
];
diff --git a/website/src/components/Header/Header.tsx b/website/src/components/Header/Header.tsx
index 4ea453c6..d0dc8a06 100644
--- a/website/src/components/Header/Header.tsx
+++ b/website/src/components/Header/Header.tsx
@@ -2,6 +2,7 @@ import { Box, Button, Text, useColorMode } from "@chakra-ui/react";
import Image from "next/image";
import Link from "next/link";
import { useSession } from "next-auth/react";
+import { Flags } from "react-feature-flags";
import { FaUser } from "react-icons/fa";
import { UserMenu } from "./UserMenu";
@@ -42,6 +43,9 @@ export function Header(props) {
diff --git a/website/src/components/Layout.tsx b/website/src/components/Layout.tsx
index 1faefcc0..c4450f58 100644
--- a/website/src/components/Layout.tsx
+++ b/website/src/components/Layout.tsx
@@ -1,7 +1,7 @@
// https://nextjs.org/docs/basic-features/layouts
import type { NextPage } from "next";
-import { FiLayout, FiMessageSquare } from "react-icons/fi";
+import { FiLayout, FiMessageSquare, FiUsers } from "react-icons/fi";
import { Header } from "src/components/Header";
import { Footer } from "./Footer";
@@ -51,4 +51,22 @@ export const getDashboardLayout = (page: React.ReactElement) => (
);
+export const getAdminLayout = (page: React.ReactElement) => (
+
+
+
+ {page}
+
+
+);
+
export const noLayout = (page: React.ReactElement) => page;
diff --git a/website/src/components/Loading/LoadingScreen.jsx b/website/src/components/Loading/LoadingScreen.jsx
index 02aabe7a..3595b3c4 100644
--- a/website/src/components/Loading/LoadingScreen.jsx
+++ b/website/src/components/Loading/LoadingScreen.jsx
@@ -1,7 +1,7 @@
import { Progress } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
-export const LoadingScreen = ({ text }) => {
+export const LoadingScreen = ({ text = "Loading..." } = {}) => {
const { colorMode } = useColorMode();
const mainClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
diff --git a/website/src/components/Messages.tsx b/website/src/components/Messages.tsx
index d3d7b3b8..fb84559e 100644
--- a/website/src/components/Messages.tsx
+++ b/website/src/components/Messages.tsx
@@ -1,36 +1,38 @@
import { Grid } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
+import { useMemo } from "react";
import { FlaggableElement } from "./FlaggableElement";
export interface Message {
text: string;
is_assistant: boolean;
+ message_id: string;
}
-const getBgColor = (isAssistant: boolean, colorMode: "light" | "dark") => {
- if (colorMode === "light") {
- return isAssistant ? "bg-slate-800" : "bg-sky-900";
- } else {
- return isAssistant ? "bg-black" : "bg-sky-900";
- }
-};
-
export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => {
- const { colorMode } = useColorMode();
-
- const items = messages.map(({ text, is_assistant }: Message, i: number) => {
+ const items = messages.map((messageProps: Message, i: number) => {
+ const { message_id, text } = messageProps;
return (
-
-
- {text}
-
+
+
);
});
// Maybe also show a legend of the colors?
return {items};
};
+
+export const MessageView = ({ is_assistant, text, message_id }: Message) => {
+ const { colorMode } = useColorMode();
+
+ const bgColor = useMemo(() => {
+ if (colorMode === "light") {
+ return is_assistant ? "bg-slate-800" : "bg-sky-900";
+ } else {
+ return is_assistant ? "bg-black" : "bg-sky-900";
+ }
+ }, [colorMode, is_assistant]);
+
+ return {text}
;
+};
diff --git a/website/src/components/Survey/TaskControls.tsx b/website/src/components/Survey/TaskControls.tsx
index 851e659c..7f419fcb 100644
--- a/website/src/components/Survey/TaskControls.tsx
+++ b/website/src/components/Survey/TaskControls.tsx
@@ -1,5 +1,6 @@
import { useColorMode } from "@chakra-ui/react";
import { Flex } from "@chakra-ui/react";
+import clsx from "clsx";
import { SkipButton } from "src/components/Buttons/Skip";
import { SubmitButton } from "src/components/Buttons/Submit";
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
@@ -10,31 +11,38 @@ export interface TaskControlsProps {
tasks: any[];
className?: string;
onSubmitResponse: (task: { id: string }) => void;
- onSkip: () => void;
+ onSkipTask: (task: { id: string }, reason: string) => void;
+ onNextTask: () => void;
}
export const TaskControls = (props: TaskControlsProps) => {
- const extraClases = props.className || "";
const { colorMode } = useColorMode();
-
- const baseClasses = "flex flex-row justify-items-stretch mb-8 p-4 rounded-lg max-w-7xl mx-auto";
- const taskControlClases =
- colorMode === "light"
- ? `${baseClasses} bg-white text-gray-800 shadow-lg ${extraClases}`
- : `${baseClasses} bg-slate-800 text-slate-400 shadow-xl ring-1 ring-white/10 ring-inset ${extraClases}`;
-
+ const isLightMode = colorMode === "light";
const endTask = props.tasks[props.tasks.length - 1];
return (
-
+
- Skip
+ {
+ props.onSkipTask(props.tasks[0], reason);
+ }}
+ />
{endTask.task.type !== "task_done" ? (
props.onSubmitResponse(props.tasks[0])}>
Submit
) : (
-
+
Next Task
)}
diff --git a/website/src/components/Survey/TrackedTextarea.tsx b/website/src/components/Survey/TrackedTextarea.tsx
index f1691b72..d20107ac 100644
--- a/website/src/components/Survey/TrackedTextarea.tsx
+++ b/website/src/components/Survey/TrackedTextarea.tsx
@@ -28,7 +28,7 @@ export const TrackedTextarea = (props: TrackedTextboxProps) => {
return (
-
+
);
diff --git a/website/src/components/TaskInfo/TaskInfo.tsx b/website/src/components/TaskInfo/TaskInfo.tsx
index 86fd2d96..c692e172 100644
--- a/website/src/components/TaskInfo/TaskInfo.tsx
+++ b/website/src/components/TaskInfo/TaskInfo.tsx
@@ -1,6 +1,6 @@
export const TaskInfo = ({ id, output }: { id: string; output: string }) => {
return (
-
+
Prompt
{id}
Output
diff --git a/website/src/components/TaskSelection/TaskOption.tsx b/website/src/components/TaskSelection/TaskOption.tsx
deleted file mode 100644
index 764efa68..00000000
--- a/website/src/components/TaskSelection/TaskOption.tsx
+++ /dev/null
@@ -1,39 +0,0 @@
-import { Card, CardBody, Flex, Heading } from "@chakra-ui/react";
-import Image from "next/image";
-import Link from "next/link";
-
-export type OptionProps = {
- img: string;
- alt: string;
- title: string;
- link: string;
-};
-
-export const TaskOption = (props: OptionProps) => {
- const { alt, img, title, link } = props;
- return (
-
-
-
-
-
-
- {title}
-
-
-
-
-
- );
-};
diff --git a/website/src/components/TaskSelection/TaskOptions.tsx b/website/src/components/TaskSelection/TaskOptions.tsx
deleted file mode 100644
index fe24b393..00000000
--- a/website/src/components/TaskSelection/TaskOptions.tsx
+++ /dev/null
@@ -1,23 +0,0 @@
-import { Divider, Flex, Heading } from "@chakra-ui/react";
-import React from "react";
-
-export type TaskOptionsProps = {
- title: string;
- children: JSX.Element | JSX.Element[];
-};
-
-export const TaskOptions = (props: TaskOptionsProps) => {
- const { title, children } = props;
- return (
-
-
- {title}
-
-
- {children}
-
- );
-};
diff --git a/website/src/components/TaskSelection/TaskSelection.tsx b/website/src/components/TaskSelection/TaskSelection.tsx
deleted file mode 100644
index 683c80e9..00000000
--- a/website/src/components/TaskSelection/TaskSelection.tsx
+++ /dev/null
@@ -1,73 +0,0 @@
-import { Flex } from "@chakra-ui/react";
-import { useColorMode } from "@chakra-ui/react";
-import React from "react";
-
-import { TaskOption } from "./TaskOption";
-import { TaskOptions } from "./TaskOptions";
-
-export const TaskSelection = () => {
- const { colorMode } = useColorMode();
- const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
-
- return (
-
-
- {/* */}
-
-
-
-
-
- {/*
- Commented out while the backend does not support them.
- */}
-
-
-
-
-
- );
-};
diff --git a/website/src/components/TaskSelection/index.ts b/website/src/components/TaskSelection/index.ts
deleted file mode 100644
index d6d93973..00000000
--- a/website/src/components/TaskSelection/index.ts
+++ /dev/null
@@ -1,3 +0,0 @@
-export { TaskOption } from "./TaskOption";
-export { TaskOptions } from "./TaskOptions";
-export { TaskSelection } from "./TaskSelection";
diff --git a/website/src/components/Tasks/CreateTask.tsx b/website/src/components/Tasks/CreateTask.tsx
index 057177d2..7dcb0d0f 100644
--- a/website/src/components/Tasks/CreateTask.tsx
+++ b/website/src/components/Tasks/CreateTask.tsx
@@ -1,10 +1,22 @@
import { useState } from "react";
+
import { Messages } from "src/components/Messages";
import { TaskControls } from "src/components/Survey/TaskControls";
import { TrackedTextarea } from "src/components/Survey/TrackedTextarea";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
+import { TaskType } from "./TaskTypes";
-export const CreateTask = ({ tasks, taskType, trigger, mutate, mainBgClasses }) => {
+export interface CreateTaskProps {
+ // we need a task type
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ tasks: any[];
+ taskType: TaskType;
+ trigger: (update: { id: string; update_type: string; content: { text: string } }) => void;
+ onSkipTask: (task: { id: string }, reason: string) => void;
+ onNextTask: () => void;
+ mainBgClasses: string;
+}
+export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, mainBgClasses }: CreateTaskProps) => {
const task = tasks[0].task;
const [inputText, setInputText] = useState("");
@@ -20,11 +32,6 @@ export const CreateTask = ({ tasks, taskType, trigger, mutate, mainBgClasses })
});
};
- const fetchNextTask = () => {
- setInputText("");
- mutate();
- };
-
const textChangeHandler = (event: React.ChangeEvent
) => {
setInputText(event.target.value);
};
@@ -48,7 +55,15 @@ export const CreateTask = ({ tasks, taskType, trigger, mutate, mainBgClasses })
>
-
+ {
+ setInputText("");
+ onSkipTask(task, reason);
+ }}
+ onNextTask={onNextTask}
+ />
);
};
diff --git a/website/src/components/Tasks/EvaluateTask.tsx b/website/src/components/Tasks/EvaluateTask.tsx
index c45d3dbe..3871b2d9 100644
--- a/website/src/components/Tasks/EvaluateTask.tsx
+++ b/website/src/components/Tasks/EvaluateTask.tsx
@@ -5,7 +5,17 @@ import { TaskControlsOverridable } from "src/components/Survey/TaskControlsOverr
import { MessageTable } from "../Messages/MessageTable";
-export const EvaluateTask = ({ tasks, trigger, mutate, mainBgClasses }) => {
+export interface EvaluateTaskProps {
+ // we need a task type
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ tasks: any[];
+ trigger: (update: { id: string; update_type: string; content: { ranking: number[] } }) => void;
+ onSkipTask: (task: { id: string }, reason: string) => void;
+ onNextTask: () => void;
+ mainBgClasses: string;
+}
+
+export const EvaluateTask = ({ tasks, trigger, onSkipTask, onNextTask, mainBgClasses }: EvaluateTaskProps) => {
const [ranking, setRanking] = useState
([]);
const submitResponse = (task) => {
trigger({
@@ -17,10 +27,6 @@ export const EvaluateTask = ({ tasks, trigger, mutate, mainBgClasses }) => {
});
};
- const fetchNextTask = () => {
- setRanking([]);
- mutate();
- };
let messages = null;
if (tasks[0].task.conversation) {
messages = tasks[0].task.conversation.messages;
@@ -45,7 +51,11 @@ export const EvaluateTask = ({ tasks, trigger, mutate, mainBgClasses }) => {
isValid={ranking.length == tasks[0].task[sortables].length}
prepareForSubmit={() => setRanking(tasks[0].task[sortables].map((_, idx) => idx))}
onSubmitResponse={submitResponse}
- onSkip={fetchNextTask}
+ onSkipTask={(task, reason) => {
+ setRanking([]);
+ onSkipTask(task, reason);
+ }}
+ onNextTask={onNextTask}
/>
);
diff --git a/website/src/components/Tasks/LabelTask.tsx b/website/src/components/Tasks/LabelTask.tsx
new file mode 100644
index 00000000..bb9d417c
--- /dev/null
+++ b/website/src/components/Tasks/LabelTask.tsx
@@ -0,0 +1,100 @@
+import { Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react";
+import { useColorMode } from "@chakra-ui/react";
+import { ReactNode, useEffect, useId, useMemo, useState } from "react";
+import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
+import { colors } from "styles/Theme/colors";
+
+export const LabelTask = ({
+ title,
+ desc,
+ messages,
+ inputs,
+ controls,
+}: {
+ title: string;
+ desc: string;
+ messages: ReactNode;
+ inputs: ReactNode;
+ controls: ReactNode;
+}) => {
+ const { colorMode } = useColorMode();
+ const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
+
+ const card = useMemo(
+ () => (
+ <>
+ {title}
+ {desc}
+ {messages}
+ >
+ ),
+ [title, desc, messages]
+ );
+
+ return (
+
+
+ {card}
+ {inputs}
+
+ {controls}
+
+ );
+};
+
+// TODO: consolidate with FlaggableElement
+interface LabelSliderGroupProps {
+ labelIDs: Array;
+ onChange: (sliderValues: number[]) => unknown;
+}
+
+export const LabelSliderGroup = ({ labelIDs, onChange }: LabelSliderGroupProps) => {
+ const [sliderValues, setSliderValues] = useState(Array.from({ length: labelIDs.length }).map(() => 0));
+
+ useEffect(() => {
+ onChange(sliderValues);
+ }, [sliderValues, onChange]);
+
+ return (
+
+ {labelIDs.map((labelId, idx) => (
+ {
+ const newState = sliderValues.slice();
+ newState[idx] = sliderValue;
+ setSliderValues(newState);
+ }}
+ />
+ ))}
+
+ );
+};
+
+function CheckboxSliderItem(props: {
+ labelId: string;
+ sliderValue: number;
+ sliderHandler: (newVal: number) => unknown;
+}) {
+ const id = useId();
+ const { colorMode } = useColorMode();
+
+ const labelTextClass = colorMode === "light" ? `text-${colors.light.text}` : `text-${colors.dark.text}`;
+
+ return (
+ <>
+
+ props.sliderHandler(val / 100)}>
+
+
+
+
+
+ >
+ );
+}
diff --git a/website/src/components/Tasks/Task.tsx b/website/src/components/Tasks/Task.tsx
index ce9505e9..153e0a93 100644
--- a/website/src/components/Tasks/Task.tsx
+++ b/website/src/components/Tasks/Task.tsx
@@ -1,10 +1,25 @@
import { CreateTask } from "./CreateTask";
import { EvaluateTask } from "./EvaluateTask";
import { TaskCategory, TaskTypes } from "./TaskTypes";
+import useSWRMutation from "swr/mutation";
+import poster from "src/lib/poster";
export const Task = ({ tasks, trigger, mutate, mainBgClasses }) => {
const task = tasks[0].task;
+ const { trigger: sendRejection } = useSWRMutation("/api/reject_task", poster, {
+ onSuccess: async () => {
+ mutate();
+ },
+ });
+
+ const rejectTask = (task: { id: string }, reason: string) => {
+ sendRejection({
+ id: task.id,
+ reason,
+ });
+ };
+
function taskTypeComponent(type) {
const taskType = TaskTypes.find((taskType) => taskType.type === type);
const category = taskType.category;
@@ -14,13 +29,22 @@ export const Task = ({ tasks, trigger, mutate, mainBgClasses }) => {
);
case TaskCategory.Evaluate:
- return ;
+ return (
+
+ );
}
}
diff --git a/website/src/components/Tasks/TaskTypes.tsx b/website/src/components/Tasks/TaskTypes.tsx
index 413a1e16..409e7038 100644
--- a/website/src/components/Tasks/TaskTypes.tsx
+++ b/website/src/components/Tasks/TaskTypes.tsx
@@ -1,9 +1,21 @@
export enum TaskCategory {
- Create,
- Evaluate,
+ Create = "Create",
+ Evaluate = "Evaluate",
+ Label = "Label",
}
-export const TaskTypes = [
+export interface TaskType {
+ label: string;
+ desc: string;
+ category: TaskCategory;
+ pathname: string;
+ type: string;
+ overview?: string;
+ instruction?: string;
+}
+
+export const TaskTypes: TaskType[] = [
+ // create
{
label: "Create Initial Prompts",
desc: "Write initial prompts to help Open Assistant to try replying to diverse messages.",
@@ -31,6 +43,7 @@ export const TaskTypes = [
overview: "Given the following conversation, provide an adequate reply",
instruction: "Provide the assistant`s reply",
},
+ // evaluate
{
label: "Rank User Replies",
category: TaskCategory.Evaluate,
@@ -52,4 +65,26 @@ export const TaskTypes = [
pathname: "/evaluate/rank_initial_prompts",
type: "rank_initial_prompts",
},
+ // label
+ {
+ label: "Label Initial Prompt",
+ desc: "Provide labels for a prompt.",
+ category: TaskCategory.Label,
+ pathname: "/label/label_initial_prompt",
+ type: "label_initial_prompt",
+ },
+ {
+ label: "Label Prompter Reply",
+ desc: "Provide labels for a prompt.",
+ category: TaskCategory.Label,
+ pathname: "/label/label_prompter_reply",
+ type: "label_prompter_reply",
+ },
+ {
+ label: "Label Assistant Reply",
+ desc: "Provide labels for a prompt.",
+ category: TaskCategory.Label,
+ pathname: "/label/label_assistant_reply",
+ type: "label_assistant_reply",
+ },
];
diff --git a/website/src/components/UsersCell.tsx b/website/src/components/UsersCell.tsx
index b2b04c83..5354ee5c 100644
--- a/website/src/components/UsersCell.tsx
+++ b/website/src/components/UsersCell.tsx
@@ -1,4 +1,18 @@
-import { Table, TableCaption, TableContainer, Tbody, Td, Th, Thead, Tr } from "@chakra-ui/react";
+import {
+ Button,
+ Flex,
+ Spacer,
+ Stack,
+ Table,
+ TableCaption,
+ TableContainer,
+ Tbody,
+ Td,
+ Th,
+ Thead,
+ Tr,
+} from "@chakra-ui/react";
+import Link from "next/link";
import { useState } from "react";
import fetcher from "src/lib/fetcher";
import useSWR from "swr";
@@ -7,37 +21,60 @@ import useSWR from "swr";
* Fetches users from the users api route and then presents them in a simple Chakra table.
*/
const UsersCell = () => {
- // Fetch and save the users.
+ const [pageIndex, setPageIndex] = useState(0);
const [users, setUsers] = useState([]);
- const { isLoading } = useSWR("/api/admin/users", fetcher, {
+
+ // Fetch and save the users.
+ // This follows useSWR's recommendation for simple pagination:
+ // https://swr.vercel.app/docs/pagination#when-to-use-useswr
+ useSWR(`/api/admin/users?pageIndex=${pageIndex}`, fetcher, {
onSuccess: setUsers,
});
+ const toPreviousPage = () => {
+ setPageIndex(Math.max(0, pageIndex - 1));
+ };
+
+ const toNextPage = () => {
+ setPageIndex(pageIndex + 1);
+ };
+
// Present users in a naive table.
return (
-
-
- Users
-
-
- | Id |
- Email |
- Name |
- Role |
-
-
-
- {users.map((user, index) => (
-
- | {user.id} |
- {user.email} |
- {user.name} |
- {user.role} |
+
+
+
+
+
+
+
+
+ Users
+
+
+ | Id |
+ Email |
+ Name |
+ Role |
+ Update |
- ))}
-
-
-
+
+
+ {users.map((user, index) => (
+
+ | {user.id} |
+ {user.email} |
+ {user.name} |
+ {user.role} |
+
+ Manage
+ |
+
+ ))}
+
+
+
+
);
};
diff --git a/website/src/flags.ts b/website/src/flags.ts
new file mode 100644
index 00000000..2e59ec1a
--- /dev/null
+++ b/website/src/flags.ts
@@ -0,0 +1,3 @@
+const flags = [{ name: "flagTest", isActive: false }];
+
+export default flags;
diff --git a/website/src/hooks/tasks/create/useCreateInitialPrompt.ts b/website/src/hooks/tasks/create/useCreateInitialPrompt.ts
new file mode 100644
index 00000000..cf0193e8
--- /dev/null
+++ b/website/src/hooks/tasks/create/useCreateInitialPrompt.ts
@@ -0,0 +1,9 @@
+import { useGenericTaskAPI } from "../useGenericTaskAPI";
+
+interface CreateInitialPromptTask {
+ id: string;
+ type: "initial_prompt";
+ hint: string;
+}
+
+export const useCreateInitialPrompt = () => useGenericTaskAPI("initial_prompt");
diff --git a/website/src/hooks/tasks/create/useCreateReply.ts b/website/src/hooks/tasks/create/useCreateReply.ts
new file mode 100644
index 00000000..0bc78319
--- /dev/null
+++ b/website/src/hooks/tasks/create/useCreateReply.ts
@@ -0,0 +1,24 @@
+import { useGenericTaskAPI } from "../useGenericTaskAPI";
+
+interface BaseCreateReplyTask {
+ id: string;
+ conversation: {
+ messages: Array<{
+ text: string;
+ is_assistant: boolean;
+ message_id: string;
+ }>;
+ };
+}
+
+export interface CreateAssistantReplyTask extends BaseCreateReplyTask {
+ type: "assistant_reply";
+}
+
+export interface CreatePrompterReplyTask extends BaseCreateReplyTask {
+ type: "prompter_reply";
+}
+
+export const useCreateAssistantReply = () => useGenericTaskAPI("assistant_reply");
+
+export const useCreatePrompterReply = () => useGenericTaskAPI("prompter_reply");
diff --git a/website/src/hooks/tasks/evaluate/useRankInitialPrompts.ts b/website/src/hooks/tasks/evaluate/useRankInitialPrompts.ts
new file mode 100644
index 00000000..da772c80
--- /dev/null
+++ b/website/src/hooks/tasks/evaluate/useRankInitialPrompts.ts
@@ -0,0 +1,9 @@
+import { useGenericTaskAPI } from "../useGenericTaskAPI";
+
+interface RankInitialPromptsTask {
+ id: string;
+ type: "rank_initial_prompts";
+ prompts: string[];
+}
+
+export const useRankInitialPromptsTask = () => useGenericTaskAPI("rank_initial_prompts");
diff --git a/website/src/hooks/tasks/evaluate/useRankReplies.ts b/website/src/hooks/tasks/evaluate/useRankReplies.ts
new file mode 100644
index 00000000..2d8d513f
--- /dev/null
+++ b/website/src/hooks/tasks/evaluate/useRankReplies.ts
@@ -0,0 +1,25 @@
+import { useGenericTaskAPI } from "../useGenericTaskAPI";
+
+interface BaseRankRepliesTask {
+ id: string;
+ replies: string[];
+ conversation: {
+ messages: Array<{
+ text: string;
+ is_assistant: boolean;
+ message_id: string;
+ }>;
+ };
+}
+
+interface RankAssistantRepliesTask extends BaseRankRepliesTask {
+ type: "rank_assistant_replies";
+}
+
+interface RankPrompterRepliesTask extends BaseRankRepliesTask {
+ type: "rank_prompter_replies";
+}
+
+export const useRankAssistantRepliesTask = () => useGenericTaskAPI("rank_assistant_replies");
+
+export const useRankPrompterRepliesTask = () => useGenericTaskAPI("rank_prompter_replies");
diff --git a/website/src/hooks/tasks/labeling/useLabelAssistantReply.ts b/website/src/hooks/tasks/labeling/useLabelAssistantReply.ts
new file mode 100644
index 00000000..3c44046e
--- /dev/null
+++ b/website/src/hooks/tasks/labeling/useLabelAssistantReply.ts
@@ -0,0 +1,22 @@
+import { TaskResponse } from "../useGenericTaskAPI";
+import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
+
+export interface LabelAssistantReplyTask {
+ id: string;
+ type: LabelingTaskType.label_assistant_reply;
+ message_id: string;
+ valid_labels: string[];
+ reply: string;
+ conversation: {
+ messages: Array<{
+ text: string;
+ is_assistant: boolean;
+ message_id: string;
+ }>;
+ };
+}
+
+export type LabelAssistantReplyTaskResponse = TaskResponse;
+
+export const useLabelAssistantReplyTask = () =>
+ useLabelingTask(LabelingTaskType.label_assistant_reply);
diff --git a/website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx b/website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx
new file mode 100644
index 00000000..f7ba8ab5
--- /dev/null
+++ b/website/src/hooks/tasks/labeling/useLabelInitialPrompt.tsx
@@ -0,0 +1,15 @@
+import { TaskResponse } from "../useGenericTaskAPI";
+import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
+
+export interface LabelInitialPromptTask {
+ id: string;
+ type: LabelingTaskType.label_initial_prompt;
+ message_id: string;
+ valid_labels: string[];
+ prompt: string;
+}
+
+export type LabelInitialPromptTaskResponse = TaskResponse;
+
+export const useLabelInitialPromptTask = () =>
+ useLabelingTask(LabelingTaskType.label_initial_prompt);
diff --git a/website/src/hooks/tasks/labeling/useLabelPrompterReply.ts b/website/src/hooks/tasks/labeling/useLabelPrompterReply.ts
new file mode 100644
index 00000000..9de2057f
--- /dev/null
+++ b/website/src/hooks/tasks/labeling/useLabelPrompterReply.ts
@@ -0,0 +1,22 @@
+import { TaskResponse } from "../useGenericTaskAPI";
+import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
+
+export interface LabelPrompterReplyTask {
+ id: string;
+ type: LabelingTaskType.label_prompter_reply;
+ message_id: string;
+ valid_labels: string[];
+ reply: string;
+ conversation: {
+ messages: Array<{
+ text: string;
+ is_assistant: boolean;
+ message_id: string;
+ }>;
+ };
+}
+
+export type LabelPrompterReplyTaskResponse = TaskResponse;
+
+export const useLabelPrompterReplyTask = () =>
+ useLabelingTask(LabelingTaskType.label_prompter_reply);
diff --git a/website/src/hooks/tasks/labeling/useLabelingTask.ts b/website/src/hooks/tasks/labeling/useLabelingTask.ts
new file mode 100644
index 00000000..27555284
--- /dev/null
+++ b/website/src/hooks/tasks/labeling/useLabelingTask.ts
@@ -0,0 +1,20 @@
+import { useGenericTaskAPI } from "../useGenericTaskAPI";
+
+export const enum LabelingTaskType {
+ label_initial_prompt = "label_initial_prompt",
+ label_prompter_reply = "label_prompter_reply",
+ label_assistant_reply = "label_assistant_reply",
+}
+
+export const useLabelingTask = (endpoint: LabelingTaskType) => {
+ const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI(endpoint);
+
+ const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => {
+ console.assert(validLabels.length === labelWeights.length);
+ const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]]));
+
+ return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
+ };
+
+ return { tasks, isLoading, submit, reset, error };
+};
diff --git a/website/src/hooks/tasks/useGenericTaskAPI.tsx b/website/src/hooks/tasks/useGenericTaskAPI.tsx
new file mode 100644
index 00000000..a57c9da4
--- /dev/null
+++ b/website/src/hooks/tasks/useGenericTaskAPI.tsx
@@ -0,0 +1,38 @@
+import { useState } from "react";
+import fetcher from "src/lib/fetcher";
+import poster from "src/lib/poster";
+import useSWRImmutable from "swr/immutable";
+import useSWRMutation from "swr/mutation";
+
+// TODO: type & centralize types for all tasks
+
+export interface TaskResponse {
+ id: string;
+ userId: string;
+ task: TaskType;
+}
+
+export const useGenericTaskAPI = (taskApiEndpoint: string) => {
+ type ConcreteTaskResponse = TaskResponse;
+
+ const [tasks, setTasks] = useState([]);
+
+ const { isLoading, mutate, error } = useSWRImmutable(
+ "/api/new_task/" + taskApiEndpoint,
+ fetcher,
+ {
+ onSuccess: (data) => setTasks([data]),
+ revalidateOnMount: true,
+ dedupingInterval: 500,
+ }
+ );
+
+ const { trigger } = useSWRMutation("/api/update_task", poster, {
+ onSuccess: async (response) => {
+ const newTask: ConcreteTaskResponse = await response.json();
+ setTasks((oldTasks) => [...oldTasks, newTask]);
+ },
+ });
+
+ return { tasks, isLoading, trigger, error, reset: mutate };
+};
diff --git a/website/src/lib/auth.ts b/website/src/lib/auth.ts
new file mode 100644
index 00000000..5fa20f48
--- /dev/null
+++ b/website/src/lib/auth.ts
@@ -0,0 +1,19 @@
+import type { NextApiRequest, NextApiResponse } from "next";
+import { getToken } from "next-auth/jwt";
+
+/**
+ * Wraps any API Route handler and verifies that the user has the appropriate
+ * role before running the handler. Returns a 403 otherwise.
+ */
+const withRole = (role: string, handler: (arg0: NextApiRequest, arg1: NextApiResponse) => void) => {
+ return async (req: NextApiRequest, res: NextApiResponse) => {
+ const token = await getToken({ req });
+ if (!token || token.role !== role) {
+ res.status(403).end();
+ return;
+ }
+ return handler(req, res);
+ };
+};
+
+export default withRole;
diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts
index 4cf891e1..889d8b5b 100644
--- a/website/src/lib/oasst_api_client.ts
+++ b/website/src/lib/oasst_api_client.ts
@@ -42,7 +42,7 @@ export class OasstApiClient {
} catch (e) {
throw new OasstError(errorText, 0, resp.status);
}
- throw new OasstError(error.message, error.error_code, resp.status);
+ throw new OasstError(error.message ?? error, error.error_code, resp.status);
}
return await resp.json();
@@ -68,6 +68,12 @@ export class OasstApiClient {
});
}
+ async nackTask(taskId: string, reason: string): Promise {
+ return this.post(`/api/v1/tasks/${taskId}/nack`, {
+ reason,
+ });
+ }
+
// TODO return a strongly typed Task?
// This method is used to record interaction with task while fetching next task.
// This is a raw Json type, so we can't use it to strongly type the task.
diff --git a/website/src/middleware.ts b/website/src/middleware.ts
index b6a539b4..d1cd6801 100644
--- a/website/src/middleware.ts
+++ b/website/src/middleware.ts
@@ -1,8 +1,8 @@
export { default } from "next-auth/middleware";
/**
- * Guards all pages under `/grading` and redirects them to the sign in page.
+ * Guards these pages and redirects them to the sign in page.
*/
export const config = {
- matcher: ["/create/:path*", "/evaluate/:path*", "/account/:path*", "/dashboard"],
+ matcher: ["/create/:path*", "/evaluate/:path*", "/label/:path*", "/account/:path*", "/dashboard", "/admin/:path*"],
};
diff --git a/website/src/pages/_app.tsx b/website/src/pages/_app.tsx
index ab7655cd..69d212e8 100644
--- a/website/src/pages/_app.tsx
+++ b/website/src/pages/_app.tsx
@@ -3,7 +3,9 @@ import "focus-visible";
import type { AppProps } from "next/app";
import { SessionProvider } from "next-auth/react";
+import { FlagsProvider } from "react-feature-flags";
import { getDefaultLayout, NextPageWithLayout } from "src/components/Layout";
+import flags from "src/flags";
import { Chakra, getServerSideProps } from "../styles/Chakra";
@@ -16,9 +18,11 @@ function MyApp({ Component, pageProps: { session, cookies, ...pageProps } }: App
const page = getLayout();
return (
-
- {page}
-
+
+
+ {page}
+
+
);
}
export { getServerSideProps };
diff --git a/website/src/pages/admin/index.tsx b/website/src/pages/admin/index.tsx
index 60d61903..9cbea222 100644
--- a/website/src/pages/admin/index.tsx
+++ b/website/src/pages/admin/index.tsx
@@ -2,7 +2,7 @@ import Head from "next/head";
import { useRouter } from "next/router";
import { useSession } from "next-auth/react";
import { useEffect } from "react";
-import { getTransparentHeaderLayout } from "src/components/Layout";
+import { getAdminLayout } from "src/components/Layout";
import UsersCell from "src/components/UsersCell";
/**
@@ -26,10 +26,8 @@ const AdminIndex = () => {
return;
}
router.push("/");
- }, [session, status]);
+ }, [router, session, status]);
- // Show the final page.
- // TODO(#237): Display a component that fetches actual user data.
return (
<>
@@ -44,6 +42,6 @@ const AdminIndex = () => {
);
};
-AdminIndex.getLayout = getTransparentHeaderLayout;
+AdminIndex.getLayout = getAdminLayout;
export default AdminIndex;
diff --git a/website/src/pages/admin/manage_user/[id].tsx b/website/src/pages/admin/manage_user/[id].tsx
new file mode 100644
index 00000000..cdd4746e
--- /dev/null
+++ b/website/src/pages/admin/manage_user/[id].tsx
@@ -0,0 +1,131 @@
+import { Button, Container, FormControl, FormLabel, Input, Select, useToast } from "@chakra-ui/react";
+import { Field, Form, Formik } from "formik";
+import Head from "next/head";
+import { useRouter } from "next/router";
+import { useSession } from "next-auth/react";
+import { useEffect } from "react";
+import { getAdminLayout } from "src/components/Layout";
+import poster from "src/lib/poster";
+import prisma from "src/lib/prismadb";
+import useSWRMutation from "swr/mutation";
+
+const ManageUser = ({ user }) => {
+ const toast = useToast();
+ const router = useRouter();
+ const { data: session, status } = useSession();
+
+ // Check when the user session is loaded and re-route if the user is not an
+ // admin. This follows the suggestion by NextJS for handling private pages:
+ // https://nextjs.org/docs/api-reference/next/router#usage
+ //
+ // All admin pages should use the same check and routing steps.
+ useEffect(() => {
+ if (status === "loading") {
+ return;
+ }
+ if (session?.user?.role === "admin") {
+ return;
+ }
+ router.push("/");
+ }, [router, session, status]);
+
+ // Trigger to let us update the user's role. Triggers a toast when complete.
+ const { trigger } = useSWRMutation("/api/admin/update_user", poster, {
+ onSuccess: () => {
+ toast({
+ title: "User Role Updated",
+ status: "success",
+ duration: 1000,
+ isClosable: true,
+ });
+ },
+ onError: () => {
+ toast({
+ title: "User Role update failed",
+ status: "error",
+ duration: 1000,
+ isClosable: true,
+ });
+ },
+ });
+
+ return (
+ <>
+
+ Manage Users - Open Assistant
+
+
+
+ {
+ trigger(values);
+ }}
+ >
+
+
+
+ >
+ );
+};
+
+/**
+ * Fetch the user's data on the server side when rendering.
+ */
+export async function getServerSideProps({ query }) {
+ const user = await prisma.user.findUnique({
+ where: { id: query.id },
+ select: {
+ id: true,
+ name: true,
+ email: true,
+ role: true,
+ },
+ });
+ return {
+ props: {
+ user,
+ },
+ };
+}
+
+ManageUser.getLayout = getAdminLayout;
+
+export default ManageUser;
diff --git a/website/src/pages/api/admin/update_user.ts b/website/src/pages/api/admin/update_user.ts
new file mode 100644
index 00000000..a717e3d8
--- /dev/null
+++ b/website/src/pages/api/admin/update_user.ts
@@ -0,0 +1,22 @@
+import withRole from "src/lib/auth";
+import prisma from "src/lib/prismadb";
+
+/**
+ * Update's the user's data in the database. Accessible only to admins.
+ */
+const handler = withRole("admin", async (req, res) => {
+ const { id, role } = JSON.parse(req.body);
+
+ await prisma.user.update({
+ where: {
+ id,
+ },
+ data: {
+ role,
+ },
+ });
+
+ res.status(200).end();
+});
+
+export default handler;
diff --git a/website/src/pages/api/admin/users.ts b/website/src/pages/api/admin/users.ts
index 186bb253..ea8d59d9 100644
--- a/website/src/pages/api/admin/users.ts
+++ b/website/src/pages/api/admin/users.ts
@@ -1,31 +1,34 @@
-import { getToken } from "next-auth/jwt";
-import client from "src/lib/prismadb";
+import withRole from "src/lib/auth";
+import prisma from "src/lib/prismadb";
+
+// The number of users to fetch in any request.
+const PAGE_SIZE = 20;
/**
* Returns a list of user results from the database when the requesting user is
* a logged in admin.
*/
-const handler = async (req, res) => {
- const token = await getToken({ req });
-
- // Return nothing if the user isn't registered or if the user isn't an admin.
- if (!token || token.role !== "admin") {
- res.status(403).end();
- return;
- }
+const handler = withRole("admin", async (req, res) => {
+ // Figure out the pagination index and skip that number of users.
+ //
+ // Note: with Prisma this isn't the most efficient but it's the only possible
+ // option with cuid based User IDs.
+ const { pageIndex } = req.query;
+ const skip = parseInt(pageIndex as string) * PAGE_SIZE || 0;
// Fetch 20 users.
- const users = await client.user.findMany({
+ const users = await prisma.user.findMany({
select: {
id: true,
role: true,
name: true,
email: true,
},
- take: 20,
+ skip,
+ take: PAGE_SIZE,
});
res.status(200).json(users);
-};
+});
export default handler;
diff --git a/website/src/pages/api/new_task/[task_type].ts b/website/src/pages/api/new_task/[task_type].ts
index addcf3d8..9f3be55c 100644
--- a/website/src/pages/api/new_task/[task_type].ts
+++ b/website/src/pages/api/new_task/[task_type].ts
@@ -36,9 +36,6 @@ const handler = async (req, res) => {
},
});
- // Update the backend with our Task ID
- await oasstApiClient.ackTask(task.id, registeredTask.id);
-
// Send the results to the client.
res.status(200).json(registeredTask);
};
diff --git a/website/src/pages/api/reject_task.ts b/website/src/pages/api/reject_task.ts
new file mode 100644
index 00000000..d146c44b
--- /dev/null
+++ b/website/src/pages/api/reject_task.ts
@@ -0,0 +1,29 @@
+import { Prisma } from "@prisma/client";
+import { getToken } from "next-auth/jwt";
+import { oasstApiClient } from "src/lib/oasst_api_client";
+
+const handler = async (req, res) => {
+ const token = await getToken({ req });
+
+ // Return nothing if the user isn't registered.
+ if (!token) {
+ res.status(401).end();
+ return;
+ }
+
+ // Parse out the local task ID and the interaction contents.
+ const { id: frontendId, reason } = await JSON.parse(req.body);
+
+ const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } });
+
+ const task = registeredTask.task as Prisma.JsonObject;
+ const id = task.id as string;
+
+ // Update the backend with the rejection
+ await oasstApiClient.nackTask(id, reason);
+
+ // Send the results to the client.
+ res.status(200).json({});
+};
+
+export default handler;
diff --git a/website/src/pages/api/set_label.ts b/website/src/pages/api/set_label.ts
new file mode 100644
index 00000000..4db5ddaf
--- /dev/null
+++ b/website/src/pages/api/set_label.ts
@@ -0,0 +1,41 @@
+import { getToken } from "next-auth/jwt";
+import prisma from "src/lib/prismadb";
+
+/**
+ * Sets the Label in the Backend.
+ *
+ */
+const handler = async (req, res) => {
+ const token = await getToken({ req });
+
+ // Return nothing if the user isn't registered.
+ if (!token) {
+ res.status(401).end();
+ return;
+ }
+
+ // Parse out the local message_id, task ID and the interaction contents.
+ const { message_id, post_id, label_map, text } = await JSON.parse(req.body);
+
+ const interactionRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/text_labels`, {
+ method: "POST",
+ headers: {
+ "X-API-Key": process.env.FASTAPI_KEY,
+ "Content-Type": "application/json",
+ },
+ body: JSON.stringify({
+ type: "text_labels",
+ message_id: message_id,
+ labels: label_map,
+ text: text,
+ user: {
+ id: token.sub,
+ display_name: token.name || token.email,
+ auth_method: "local",
+ },
+ }),
+ });
+ res.status(interactionRes.status).end();
+};
+
+export default handler;
diff --git a/website/src/pages/api/update_task.ts b/website/src/pages/api/update_task.ts
index 4eea8c1e..e8e21ca9 100644
--- a/website/src/pages/api/update_task.ts
+++ b/website/src/pages/api/update_task.ts
@@ -1,3 +1,4 @@
+import { Prisma } from "@prisma/client";
import { getToken } from "next-auth/jwt";
import { oasstApiClient } from "src/lib/oasst_api_client";
import prisma from "src/lib/prismadb";
@@ -6,9 +7,11 @@ import prisma from "src/lib/prismadb";
* Stores the task interaction with the Task Backend and then returns the next task generated.
*
* This implicity does a few things:
- * 1) Stores the answer with the Task Backend.
- * 2) Records the new task in our local database.
- * 3) Returns the newly created task to the client.
+ * 1) Records the users answer in our local database.
+ * 2) Accepts the task.
+ * 3) Sends the users answer to the Task Backend.
+ * 4) Records the new task in our local database.
+ * 5) Returns the newly created task to the client.
*/
const handler = async (req, res) => {
const token = await getToken({ req });
@@ -20,7 +23,13 @@ const handler = async (req, res) => {
}
// Parse out the local task ID and the interaction contents.
- const { id, content, update_type } = await JSON.parse(req.body);
+ const { id: frontendId, content, update_type } = await JSON.parse(req.body);
+
+ // Accept the task so that we can complete it, this will probably go away soon.
+ const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } });
+ const task = registeredTask.task as Prisma.JsonObject;
+ const id = task.id as string;
+ await oasstApiClient.ackTask(id, registeredTask.id);
// Log the interaction locally to create our user_post_id needed by the Task
// Backend.
@@ -29,13 +38,18 @@ const handler = async (req, res) => {
content,
task: {
connect: {
- id,
+ id: frontendId,
},
},
},
});
- const newTask = await oasstApiClient.interactTask(update_type, id, interaction.id, content, token);
+ let newTask;
+ try {
+ newTask = await oasstApiClient.interactTask(update_type, frontendId, interaction.id, content, token);
+ } catch (err) {
+ return res.status(500).json(err);
+ }
// Stores the new task with our database.
const newRegisteredTask = await prisma.registeredTask.create({
diff --git a/website/src/pages/auth/signin.tsx b/website/src/pages/auth/signin.tsx
index 59fc7c05..9a1d91a8 100644
--- a/website/src/pages/auth/signin.tsx
+++ b/website/src/pages/auth/signin.tsx
@@ -2,17 +2,59 @@ import { Button, Input, Stack } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import Link from "next/link";
+import { useRouter } from "next/router";
import { getCsrfToken, getProviders, signIn } from "next-auth/react";
-import React, { useRef } from "react";
+import React, { useEffect, useRef, useState } from "react";
import { FaBug, FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
import { AuthLayout } from "src/components/AuthLayout";
import { Footer } from "src/components/Footer";
import { Header } from "src/components/Header";
+export type SignInErrorTypes =
+ | "Signin"
+ | "OAuthSignin"
+ | "OAuthCallback"
+ | "OAuthCreateAccount"
+ | "EmailCreateAccount"
+ | "Callback"
+ | "OAuthAccountNotLinked"
+ | "EmailSignin"
+ | "CredentialsSignin"
+ | "SessionRequired"
+ | "default";
+
+const errorMessages: Record = {
+ Signin: "Try signing in with a different account.",
+ OAuthSignin: "Try signing in with a different account.",
+ OAuthCallback: "Try signing in with the same account you used originally.",
+ OAuthCreateAccount: "Try signing in with a different account.",
+ EmailCreateAccount: "Try signing in with a different account.",
+ Callback: "Try signing in with a different account.",
+ OAuthAccountNotLinked: "To confirm your identity, sign in with the same account you used originally.",
+ EmailSignin: "The e-mail could not be sent.",
+ CredentialsSignin: "Sign in failed. Check the details you provided are correct.",
+ SessionRequired: "Please sign in to access this page.",
+ default: "Unable to sign in.",
+};
+
// eslint-disable-next-line @typescript-eslint/no-unused-vars
function Signin({ csrfToken, providers }) {
+ const router = useRouter();
const { discord, email, github, credentials } = providers;
const emailEl = useRef(null);
+ const [error, setError] = useState("");
+
+ useEffect(() => {
+ const err = router?.query?.error;
+ if (err) {
+ if (typeof err === "string") {
+ setError(errorMessages[err]);
+ } else {
+ setError(errorMessages[err[0]]);
+ }
+ }
+ }, [router]);
+
const signinWithEmail = (ev: React.FormEvent) => {
ev.preventDefault();
signIn(email.id, { callbackUrl: "/dashboard", email: emailEl.current.value });
@@ -110,6 +152,11 @@ function Signin({ csrfToken, providers }) {
.
+ {error && (
+
+ )}
);
diff --git a/website/src/pages/auth/verify.tsx b/website/src/pages/auth/verify.tsx
index e004f504..b4d7d739 100644
--- a/website/src/pages/auth/verify.tsx
+++ b/website/src/pages/auth/verify.tsx
@@ -1,17 +1,23 @@
+import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
import { getCsrfToken, getProviders } from "next-auth/react";
import { AuthLayout } from "src/components/AuthLayout";
export default function Verify() {
+ const { colorMode } = useColorMode();
+ const bgColorClass = colorMode === "light" ? "bg-gray-50" : "bg-chakra-gray-900";
+
return (
<>
Sign Up - Open Assistant
-
- A sign-in link has been sent to your email address.
-
+
+
+
A sign-in link has been sent to your email address.
+
+
>
);
}
diff --git a/website/src/pages/create/assistant_reply.tsx b/website/src/pages/create/assistant_reply.tsx
index 6dad3931..e9aee226 100644
--- a/website/src/pages/create/assistant_reply.tsx
+++ b/website/src/pages/create/assistant_reply.tsx
@@ -1,35 +1,12 @@
import { Container } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
-import { useEffect, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
-import fetcher from "src/lib/fetcher";
-import poster from "src/lib/poster";
-import useSWRImmutable from "swr/immutable";
-import useSWRMutation from "swr/mutation";
+import { useCreateAssistantReply } from "src/hooks/tasks/create/useCreateReply";
const AssistantReply = () => {
- const [tasks, setTasks] = useState([]);
-
- const { isLoading, mutate } = useSWRImmutable("/api/new_task/assistant_reply ", fetcher, {
- onSuccess: (data) => {
- setTasks([data]);
- },
- });
-
- useEffect(() => {
- if (tasks.length == 0) {
- mutate();
- }
- }, [tasks]);
-
- const { trigger } = useSWRMutation("/api/update_task", poster, {
- onSuccess: async (data) => {
- const newTask = await data.json();
- setTasks((oldTasks) => [...oldTasks, newTask]);
- },
- });
+ const { tasks, isLoading, reset, trigger } = useCreateAssistantReply();
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
@@ -38,7 +15,7 @@ const AssistantReply = () => {
return ;
}
- if (tasks.length == 0) {
+ if (tasks.length === 0) {
return No tasks found...;
}
@@ -48,7 +25,7 @@ const AssistantReply = () => {
Reply as Assistant
-
+
>
);
};
diff --git a/website/src/pages/create/initial_prompt.tsx b/website/src/pages/create/initial_prompt.tsx
index fc9ba39b..efea6474 100644
--- a/website/src/pages/create/initial_prompt.tsx
+++ b/website/src/pages/create/initial_prompt.tsx
@@ -1,35 +1,12 @@
import { Container } from "@chakra-ui/react";
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
-import { useEffect, useState } from "react";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
-import fetcher from "src/lib/fetcher";
-import poster from "src/lib/poster";
-import useSWRImmutable from "swr/immutable";
-import useSWRMutation from "swr/mutation";
+import { useCreateInitialPrompt } from "src/hooks/tasks/create/useCreateInitialPrompt";
const InitialPrompt = () => {
- const [tasks, setTasks] = useState([]);
-
- const { isLoading, mutate } = useSWRImmutable("/api/new_task/initial_prompt ", fetcher, {
- onSuccess: (data) => {
- setTasks([data]);
- },
- });
-
- const { trigger } = useSWRMutation("/api/update_task", poster, {
- onSuccess: async (data) => {
- const newTask = await data.json();
- setTasks((oldTasks) => [...oldTasks, newTask]);
- },
- });
-
- useEffect(() => {
- if (tasks.length == 0) {
- mutate();
- }
- }, [tasks]);
+ const { tasks, isLoading, reset, trigger } = useCreateInitialPrompt();
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
@@ -38,7 +15,7 @@ const InitialPrompt = () => {
return ;
}
- if (tasks.length == 0) {
+ if (tasks.length === 0) {
return No tasks found...;
}
@@ -48,7 +25,7 @@ const InitialPrompt = () => {
Reply as Assistant
-
+
>
);
};
diff --git a/website/src/pages/create/summarize_story.tsx b/website/src/pages/create/summarize_story.tsx
index 1c5b89b9..8620a8f5 100644
--- a/website/src/pages/create/summarize_story.tsx
+++ b/website/src/pages/create/summarize_story.tsx
@@ -87,7 +87,12 @@ const SummarizeStory = () => {
>
-
+
);
diff --git a/website/src/pages/create/user_reply.tsx b/website/src/pages/create/user_reply.tsx
index 70487805..2394bd63 100644
--- a/website/src/pages/create/user_reply.tsx
+++ b/website/src/pages/create/user_reply.tsx
@@ -1,34 +1,12 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
-import { useEffect, useState } from "react";
+import { Container } from "src/components/Container";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
-import fetcher from "src/lib/fetcher";
-import poster from "src/lib/poster";
-import useSWRImmutable from "swr/immutable";
-import useSWRMutation from "swr/mutation";
+import { useCreatePrompterReply } from "src/hooks/tasks/create/useCreateReply";
const UserReply = () => {
- const [tasks, setTasks] = useState([]);
-
- const { isLoading, mutate } = useSWRImmutable("/api/new_task/prompter_reply", fetcher, {
- onSuccess: (data) => {
- setTasks([data]);
- },
- });
-
- useEffect(() => {
- if (tasks.length == 0) {
- mutate();
- }
- }, [tasks]);
-
- const { trigger } = useSWRMutation("/api/update_task", poster, {
- onSuccess: async (data) => {
- const newTask = await data.json();
- setTasks((oldTasks) => [...oldTasks, newTask]);
- },
- });
+ const { tasks, isLoading, reset, trigger } = useCreatePrompterReply();
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
@@ -37,14 +15,8 @@ const UserReply = () => {
return ;
}
- if (tasks.length == 0) {
- return (
-
- );
+ if (tasks.length === 0) {
+ return No tasks found...;
}
return (
@@ -53,7 +25,7 @@ const UserReply = () => {
Reply as Assistant
-
+
>
);
};
diff --git a/website/src/pages/evaluate/rank_assistant_replies.tsx b/website/src/pages/evaluate/rank_assistant_replies.tsx
index ccf6d55b..931c9194 100644
--- a/website/src/pages/evaluate/rank_assistant_replies.tsx
+++ b/website/src/pages/evaluate/rank_assistant_replies.tsx
@@ -1,34 +1,12 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
-import { useEffect, useState } from "react";
+import { Container } from "src/components/Container";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
-import fetcher from "src/lib/fetcher";
-import poster from "src/lib/poster";
-import useSWRImmutable from "swr/immutable";
-import useSWRMutation from "swr/mutation";
+import { useRankAssistantRepliesTask } from "src/hooks/tasks/evaluate/useRankReplies";
const RankAssistantReplies = () => {
- const [tasks, setTasks] = useState([]);
-
- const { isLoading, mutate } = useSWRImmutable("/api/new_task/rank_assistant_replies", fetcher, {
- onSuccess: (data) => {
- setTasks([data]);
- },
- });
-
- useEffect(() => {
- if (tasks.length == 0) {
- mutate();
- }
- }, [tasks]);
-
- const { trigger } = useSWRMutation("/api/update_task", poster, {
- onSuccess: async (data) => {
- const newTask = await data.json();
- setTasks((oldTasks) => [...oldTasks, newTask]);
- },
- });
+ const { tasks, isLoading, reset, trigger } = useRankAssistantRepliesTask();
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
@@ -37,14 +15,8 @@ const RankAssistantReplies = () => {
return ;
}
- if (tasks.length == 0) {
- return (
-
- );
+ if (tasks.length === 0) {
+ return No tasks found...;
}
return (
@@ -53,7 +25,7 @@ const RankAssistantReplies = () => {
Rank Assistant Replies
-
+
>
);
};
diff --git a/website/src/pages/evaluate/rank_initial_prompts.tsx b/website/src/pages/evaluate/rank_initial_prompts.tsx
index e7c69573..4b717143 100644
--- a/website/src/pages/evaluate/rank_initial_prompts.tsx
+++ b/website/src/pages/evaluate/rank_initial_prompts.tsx
@@ -1,34 +1,12 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
-import { useEffect, useState } from "react";
+import { Container } from "src/components/Container";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
-import fetcher from "src/lib/fetcher";
-import poster from "src/lib/poster";
-import useSWRImmutable from "swr/immutable";
-import useSWRMutation from "swr/mutation";
+import { useRankInitialPromptsTask } from "src/hooks/tasks/evaluate/useRankInitialPrompts";
const RankInitialPrompts = () => {
- const [tasks, setTasks] = useState([]);
-
- const { isLoading, mutate } = useSWRImmutable("/api/new_task/rank_initial_prompts", fetcher, {
- onSuccess: (data) => {
- setTasks([data]);
- },
- });
-
- const { trigger } = useSWRMutation("/api/update_task", poster, {
- onSuccess: async (data) => {
- const newTask = await data.json();
- setTasks((oldTasks) => [...oldTasks, newTask]);
- },
- });
-
- useEffect(() => {
- if (tasks.length == 0) {
- mutate();
- }
- }, [tasks]);
+ const { tasks, isLoading, reset, trigger } = useRankInitialPromptsTask();
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
@@ -37,14 +15,8 @@ const RankInitialPrompts = () => {
return ;
}
- if (tasks.length == 0) {
- return (
-
- );
+ if (tasks.length === 0) {
+ return No tasks found...;
}
return (
@@ -53,7 +25,7 @@ const RankInitialPrompts = () => {
Rank Initial Prompts
-
+
>
);
};
diff --git a/website/src/pages/evaluate/rank_user_replies.tsx b/website/src/pages/evaluate/rank_user_replies.tsx
index 13e8923a..659874a2 100644
--- a/website/src/pages/evaluate/rank_user_replies.tsx
+++ b/website/src/pages/evaluate/rank_user_replies.tsx
@@ -1,34 +1,12 @@
import { useColorMode } from "@chakra-ui/react";
import Head from "next/head";
-import { useEffect, useState } from "react";
+import { Container } from "src/components/Container";
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
import { Task } from "src/components/Tasks/Task";
-import fetcher from "src/lib/fetcher";
-import poster from "src/lib/poster";
-import useSWRImmutable from "swr/immutable";
-import useSWRMutation from "swr/mutation";
+import { useRankPrompterRepliesTask } from "src/hooks/tasks/evaluate/useRankReplies";
const RankUserReplies = () => {
- const [tasks, setTasks] = useState([]);
-
- const { isLoading, mutate } = useSWRImmutable("/api/new_task/rank_prompter_replies", fetcher, {
- onSuccess: (data) => {
- setTasks([data]);
- },
- });
-
- useEffect(() => {
- if (tasks.length == 0) {
- mutate();
- }
- }, [tasks]);
-
- const { trigger } = useSWRMutation("/api/update_task", poster, {
- onSuccess: async (data) => {
- const newTask = await data.json();
- setTasks((oldTasks) => [...oldTasks, newTask]);
- },
- });
+ const { tasks, isLoading, reset, trigger } = useRankPrompterRepliesTask();
const { colorMode } = useColorMode();
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
@@ -37,14 +15,8 @@ const RankUserReplies = () => {
return ;
}
- if (tasks.length == 0) {
- return (
-
- );
+ if (tasks.length === 0) {
+ return No tasks found...;
}
return (
@@ -53,7 +25,7 @@ const RankUserReplies = () => {
Rank User Replies
-
+
>
);
};
diff --git a/website/src/pages/evaluate/rate_summary.tsx b/website/src/pages/evaluate/rate_summary.tsx
index e0118ece..0d2352a2 100644
--- a/website/src/pages/evaluate/rate_summary.tsx
+++ b/website/src/pages/evaluate/rate_summary.tsx
@@ -102,7 +102,12 @@ const RateSummary = () => {
-
+
>
);
diff --git a/website/src/pages/label/label_assistant_reply.tsx b/website/src/pages/label/label_assistant_reply.tsx
new file mode 100644
index 00000000..a0f961f7
--- /dev/null
+++ b/website/src/pages/label/label_assistant_reply.tsx
@@ -0,0 +1,47 @@
+import { useState } from "react";
+import { LoadingScreen } from "src/components/Loading/LoadingScreen";
+import { Message } from "src/components/Messages";
+import { MessageTable } from "src/components/Messages/MessageTable";
+import { TaskControls } from "src/components/Survey/TaskControls";
+import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
+import {
+ LabelAssistantReplyTaskResponse,
+ useLabelAssistantReplyTask,
+} from "src/hooks/tasks/labeling/useLabelAssistantReply";
+
+const LabelAssistantReply = () => {
+ const [sliderValues, setSliderValues] = useState([]);
+
+ const { tasks, isLoading, submit, reset } = useLabelAssistantReplyTask();
+
+ if (isLoading || tasks.length === 0) {
+ return ;
+ }
+
+ const task = tasks[0].task;
+ const messages: Message[] = [
+ ...task.conversation.messages,
+ { text: task.reply, is_assistant: true, message_id: task.message_id },
+ ];
+
+ return (
+ }
+ inputs={}
+ controls={
+ reset()}
+ onNextTask={reset}
+ onSubmitResponse={({ id, task }: LabelAssistantReplyTaskResponse) =>
+ submit(id, task.message_id, task.reply, task.valid_labels, sliderValues)
+ }
+ />
+ }
+ />
+ );
+};
+
+export default LabelAssistantReply;
diff --git a/website/src/pages/label/label_initial_prompt.tsx b/website/src/pages/label/label_initial_prompt.tsx
new file mode 100644
index 00000000..3c791f23
--- /dev/null
+++ b/website/src/pages/label/label_initial_prompt.tsx
@@ -0,0 +1,42 @@
+import { useState } from "react";
+import { LoadingScreen } from "src/components/Loading/LoadingScreen";
+import { MessageView } from "src/components/Messages";
+import { TaskControls } from "src/components/Survey/TaskControls";
+import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
+import {
+ LabelInitialPromptTaskResponse,
+ useLabelInitialPromptTask,
+} from "src/hooks/tasks/labeling/useLabelInitialPrompt";
+
+const LabelInitialPrompt = () => {
+ const [sliderValues, setSliderValues] = useState([]);
+
+ const { tasks, isLoading, submit, reset } = useLabelInitialPromptTask();
+
+ if (isLoading || tasks.length === 0) {
+ return ;
+ }
+
+ const task = tasks[0].task;
+
+ return (
+ }
+ inputs={}
+ controls={
+ reset()}
+ onNextTask={reset}
+ onSubmitResponse={({ id, task }: LabelInitialPromptTaskResponse) =>
+ submit(id, task.message_id, task.prompt, task.valid_labels, sliderValues)
+ }
+ />
+ }
+ />
+ );
+};
+
+export default LabelInitialPrompt;
diff --git a/website/src/pages/label/label_prompter_reply.tsx b/website/src/pages/label/label_prompter_reply.tsx
new file mode 100644
index 00000000..2fd3d76a
--- /dev/null
+++ b/website/src/pages/label/label_prompter_reply.tsx
@@ -0,0 +1,47 @@
+import { useState } from "react";
+import { LoadingScreen } from "src/components/Loading/LoadingScreen";
+import { Message } from "src/components/Messages";
+import { MessageTable } from "src/components/Messages/MessageTable";
+import { TaskControls } from "src/components/Survey/TaskControls";
+import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
+import {
+ LabelPrompterReplyTaskResponse,
+ useLabelPrompterReplyTask,
+} from "src/hooks/tasks/labeling/useLabelPrompterReply";
+
+const LabelPrompterReply = () => {
+ const [sliderValues, setSliderValues] = useState([]);
+
+ const { tasks, isLoading, submit, reset } = useLabelPrompterReplyTask();
+
+ if (isLoading || tasks.length === 0) {
+ return ;
+ }
+
+ const task = tasks[0].task;
+ const messages: Message[] = [
+ ...task.conversation.messages,
+ { text: task.reply, is_assistant: false, message_id: task.message_id },
+ ];
+
+ return (
+ }
+ inputs={}
+ controls={
+ reset()}
+ onNextTask={reset}
+ onSubmitResponse={({ id, task }: LabelPrompterReplyTaskResponse) =>
+ submit(id, task.message_id, task.reply, task.valid_labels, sliderValues)
+ }
+ />
+ }
+ />
+ );
+};
+
+export default LabelPrompterReply;
diff --git a/website/src/pages/messages/index.tsx b/website/src/pages/messages/index.tsx
index c32a90c6..2809ba5c 100644
--- a/website/src/pages/messages/index.tsx
+++ b/website/src/pages/messages/index.tsx
@@ -2,6 +2,7 @@ import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@cha
import Head from "next/head";
import { useEffect, useState } from "react";
import { getDashboardLayout } from "src/components/Layout";
+import { Message } from "src/components/Messages";
import { MessageTable } from "src/components/Messages/MessageTable";
import fetcher from "src/lib/fetcher";
import useSWRImmutable from "swr/immutable";
@@ -10,29 +11,28 @@ const MessagesDashboard = () => {
const boxBgColor = useColorModeValue("white", "gray.700");
const boxAccentColor = useColorModeValue("gray.200", "gray.900");
- const [messages, setMessages] = useState([]);
- const [userMessages, setUserMessages] = useState([]);
+ const [messages, setMessages] = useState(null);
+ const [userMessages, setUserMessages] = useState(null);
const { isLoading: isLoadingAll, mutate: mutateAll } = useSWRImmutable("/api/messages", fetcher, {
- onSuccess: (data) => {
- setMessages(data);
- },
+ onSuccess: setMessages,
});
const { isLoading: isLoadingUser, mutate: mutateUser } = useSWRImmutable(`/api/messages/user`, fetcher, {
- onSuccess: (data) => {
- setUserMessages(data);
- },
+ onSuccess: setUserMessages,
});
+ const receivedMessages = !isLoadingAll && Array.isArray(messages);
+ const receivedUserMessages = !isLoadingUser && Array.isArray(userMessages);
+
useEffect(() => {
- if (messages.length == 0) {
+ if (!receivedMessages) {
mutateAll();
}
- if (userMessages.length == 0) {
+ if (!receivedUserMessages) {
mutateUser();
}
- }, [messages, userMessages]);
+ }, [receivedMessages, mutateAll, receivedUserMessages, mutateUser]);
return (
<>
@@ -52,7 +52,7 @@ const MessagesDashboard = () => {
borderRadius="xl"
className="p-6 shadow-sm"
>
- {isLoadingAll ? : }
+ {receivedMessages ? : }
@@ -66,7 +66,7 @@ const MessagesDashboard = () => {
borderRadius="xl"
className="p-6 shadow-sm"
>
- {isLoadingUser ? : }
+ {receivedUserMessages ? : }
diff --git a/website/src/pages/tasks/all.tsx b/website/src/pages/tasks/all.tsx
new file mode 100644
index 00000000..6e4e926b
--- /dev/null
+++ b/website/src/pages/tasks/all.tsx
@@ -0,0 +1,19 @@
+import Head from "next/head";
+import { TaskOption } from "src/components/Dashboard";
+import { getDashboardLayout } from "src/components/Layout";
+
+const AllTasks = () => {
+ return (
+ <>
+
+ All Tasks - Open Assistant
+
+
+
+ >
+ );
+};
+
+AllTasks.getLayout = (page) => getDashboardLayout(page);
+
+export default AllTasks;