mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge remote-tracking branch 'upstream/main'
This commit is contained in:
@@ -8,6 +8,9 @@ on:
|
||||
- ".github/workflows/deploy-docs-site.yaml"
|
||||
- "docs/**"
|
||||
pull_request:
|
||||
paths:
|
||||
- ".github/workflows/deploy-docs-site.yaml"
|
||||
- "docs/**"
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
|
||||
@@ -16,3 +16,12 @@ jobs:
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
- name: Post PR comment on failure
|
||||
if: failure() && github.event_name == 'pull_request'
|
||||
uses: peter-evans/create-or-update-comment@v2
|
||||
with:
|
||||
issue-number: ${{ github.event.pull_request.number }}
|
||||
body: |
|
||||
:x: **pre-commit** failed.
|
||||
Please run `pre-commit run --all-files` locally and commit the changes.
|
||||
Find more information in the repository's CONTRIBUTING.md
|
||||
|
||||
@@ -29,6 +29,15 @@ jobs:
|
||||
deploy-dev:
|
||||
needs: [build-backend, build-web, build-bot]
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
WEB_ADMIN_USERS: ${{ secrets.DEV_WEB_ADMIN_USERS }}
|
||||
WEB_DISCORD_CLIENT_ID: ${{ secrets.DEV_WEB_DISCORD_CLIENT_ID }}
|
||||
WEB_DISCORD_CLIENT_SECRET: ${{ secrets.DEV_WEB_DISCORD_CLIENT_SECRET }}
|
||||
WEB_EMAIL_SERVER_HOST: ${{ secrets.DEV_WEB_EMAIL_SERVER_HOST }}
|
||||
WEB_EMAIL_SERVER_PASSWORD: ${{ secrets.DEV_WEB_EMAIL_SERVER_PASSWORD }}
|
||||
WEB_EMAIL_SERVER_PORT: ${{ secrets.DEV_WEB_EMAIL_SERVER_PORT }}
|
||||
WEB_EMAIL_SERVER_USER: ${{ secrets.DEV_WEB_EMAIL_SERVER_USER }}
|
||||
WEB_NEXTAUTH_SECRET: ${{ secrets.DEV_WEB_NEXTAUTH_SECRET }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
|
||||
+4
-2
@@ -96,8 +96,10 @@ The website is built using Next.js and is in the `website` folder.
|
||||
|
||||
### Pre-commit
|
||||
|
||||
Install `pre-commit` and run `pre-commit install` to install the pre-commit
|
||||
hooks.
|
||||
We are using `pre-commit` to enforce code style and formatting.
|
||||
|
||||
Install `pre-commit` from [its website](https://pre-commit.com) and run
|
||||
`pre-commit install` to install the pre-commit hooks.
|
||||
|
||||
In case you haven't done this, have already committed, and CI is failing, you
|
||||
can run `pre-commit run --all-files` to run the pre-commit hooks on all files.
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
To test the ansible playbook on localhost run
|
||||
`ansible-playbook -i test.inventory.ini dev.yaml`.\
|
||||
In case you're missing the ansible docker depencency install it with `ansible-galaxy collection install community.docker`.\
|
||||
Point Redis Insights to the Redis database by visiting localhost:8001 in a
|
||||
browser and select "I already have a database" followed by "Connect to a Redis
|
||||
Database".\
|
||||
For host, port and name fill in `oasst-redis`, `6379` and `redis`.
|
||||
+53
-15
@@ -10,6 +10,39 @@
|
||||
state: present
|
||||
driver: bridge
|
||||
|
||||
- name: Copy redis.conf to managed node
|
||||
ansible.builtin.copy:
|
||||
src: ./redis.conf
|
||||
dest: ./redis.conf
|
||||
|
||||
- name: Set up Redis
|
||||
community.docker.docker_container:
|
||||
name: oasst-redis
|
||||
image: redis
|
||||
state: started
|
||||
restart_policy: always
|
||||
network_mode: oasst
|
||||
ports:
|
||||
- 6379:6379
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "redis-cli ping | grep PONG"]
|
||||
interval: 2s
|
||||
timeout: 2s
|
||||
retries: 10
|
||||
command: redis-server /usr/local/etc/redis/redis.conf
|
||||
volumes:
|
||||
- "./redis.conf:/usr/local/etc/redis/redis.conf"
|
||||
|
||||
- name: Set up Redis Insights
|
||||
community.docker.docker_container:
|
||||
name: oasst-redis-insights
|
||||
image: redislabs/redisinsight:latest
|
||||
state: started
|
||||
restart_policy: always
|
||||
network_mode: oasst
|
||||
ports:
|
||||
- 8001:8001
|
||||
|
||||
- name: Create postgres containers
|
||||
community.docker.docker_container:
|
||||
name: "{{ item.name }}"
|
||||
@@ -32,14 +65,6 @@
|
||||
- name: oasst-postgres
|
||||
- name: oasst-postgres-web
|
||||
|
||||
- name: Set up maildev
|
||||
community.docker.docker_container:
|
||||
name: oasst-maildev
|
||||
image: maildev/maildev
|
||||
state: started
|
||||
restart_policy: always
|
||||
network_mode: oasst
|
||||
|
||||
- name: Run the oasst oasst-backend
|
||||
community.docker.docker_container:
|
||||
name: oasst-backend
|
||||
@@ -51,6 +76,7 @@
|
||||
network_mode: oasst
|
||||
env:
|
||||
POSTGRES_HOST: oasst-postgres
|
||||
REDIS_HOST: oasst-redis
|
||||
DEBUG_ALLOW_ANY_API_KEY: "true"
|
||||
DEBUG_USE_SEED_DATA: "true"
|
||||
MAX_WORKERS: "1"
|
||||
@@ -68,15 +94,27 @@
|
||||
restart_policy: always
|
||||
network_mode: oasst
|
||||
env:
|
||||
FASTAPI_URL: http://oasst-backend:8080
|
||||
FASTAPI_KEY: "123"
|
||||
ADMIN_USERS: "{{ lookup('ansible.builtin.env', 'WEB_ADMIN_USERS') }}"
|
||||
DATABASE_URL: postgres://postgres:postgres@oasst-postgres-web/postgres
|
||||
NEXTAUTH_SECRET: O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98=
|
||||
EMAIL_SERVER_HOST: oasst-maildev
|
||||
EMAIL_SERVER_PORT: "25"
|
||||
EMAIL_FROM: info@example.com
|
||||
NEXTAUTH_URL: http://web.dev.open-assistant.io/
|
||||
DEBUG_LOGIN: "true"
|
||||
DISCORD_CLIENT_ID:
|
||||
"{{ lookup('ansible.builtin.env', 'WEB_DISCORD_CLIENT_ID') }}"
|
||||
DISCORD_CLIENT_SECRET:
|
||||
"{{ lookup('ansible.builtin.env', 'WEB_DISCORD_CLIENT_SECRET') }}"
|
||||
EMAIL_FROM: open-assistent@laion.ai
|
||||
EMAIL_SERVER_HOST:
|
||||
"{{ lookup('ansible.builtin.env', 'WEB_EMAIL_SERVER_HOST') }}"
|
||||
EMAIL_SERVER_PASSWORD:
|
||||
"{{ lookup('ansible.builtin.env', 'WEB_EMAIL_SERVER_PASSWORD') }}"
|
||||
EMAIL_SERVER_PORT:
|
||||
"{{ lookup('ansible.builtin.env', 'WEB_EMAIL_SERVER_PORT') }}"
|
||||
EMAIL_SERVER_USER:
|
||||
"{{ lookup('ansible.builtin.env', 'WEB_EMAIL_SERVER_USER') }}"
|
||||
FASTAPI_URL: http://oasst-backend:8080
|
||||
FASTAPI_KEY: "1234"
|
||||
NEXTAUTH_SECRET:
|
||||
"{{ lookup('ansible.builtin.env', 'WEB_NEXTAUTH_SECRET') }}"
|
||||
NEXTAUTH_URL: http://web.dev.open-assistant.io/
|
||||
ports:
|
||||
- 3000:3000
|
||||
command: bash wait-for-postgres.sh node server.js
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
maxmemory 100mb
|
||||
maxmemory-policy allkeys-lru
|
||||
@@ -0,0 +1,2 @@
|
||||
[test]
|
||||
dev ansible_connection=local
|
||||
+1
-1
@@ -20,4 +20,4 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("api_client", "frontend_id")
|
||||
op.drop_column("api_client", "frontend_type")
|
||||
|
||||
+11
-6
@@ -14,7 +14,7 @@ from oasst_backend.api.deps import get_dummy_api_client
|
||||
from oasst_backend.api.v1.api import api_router
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.database import engine
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from pydantic import BaseModel
|
||||
@@ -110,7 +110,12 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
with Session(engine) as db:
|
||||
api_client = get_dummy_api_client(db)
|
||||
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
pr = PromptRepository(db=db, api_client=api_client, user=dummy_user)
|
||||
|
||||
ur = UserRepository(db=db, api_client=api_client)
|
||||
tr = TaskRepository(db=db, api_client=api_client, client_user=dummy_user, user_repository=ur)
|
||||
pr = PromptRepository(
|
||||
db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
|
||||
)
|
||||
|
||||
with open(settings.DEBUG_USE_SEED_DATA_PATH) as f:
|
||||
dummy_messages_raw = json.load(f)
|
||||
@@ -118,14 +123,14 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw]
|
||||
|
||||
for msg in dummy_messages:
|
||||
task = pr.fetch_task_by_frontend_message_id(msg.task_message_id)
|
||||
task = tr.fetch_task_by_frontend_message_id(msg.task_message_id)
|
||||
if task and not task.ack:
|
||||
logger.warning("Deleting unacknowledged seed data task")
|
||||
db.delete(task)
|
||||
task = None
|
||||
if not task:
|
||||
if msg.parent_message_id is None:
|
||||
task = pr.store_task(
|
||||
task = tr.store_task(
|
||||
protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
|
||||
)
|
||||
else:
|
||||
@@ -144,12 +149,12 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
for cmsg in conversation_messages
|
||||
]
|
||||
)
|
||||
task = pr.store_task(
|
||||
task = tr.store_task(
|
||||
protocol_schema.AssistantReplyTask(conversation=conversation),
|
||||
message_tree_id=parent_message.message_tree_id,
|
||||
parent_message_id=parent_message.id,
|
||||
)
|
||||
pr.bind_frontend_message_id(task.id, msg.task_message_id)
|
||||
tr.bind_frontend_message_id(task.id, msg.task_message_id)
|
||||
message = pr.store_text_reply(msg.text, msg.task_message_id, msg.user_message_id)
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -16,7 +16,7 @@ def get_message_by_frontend_id(
|
||||
"""
|
||||
Get a message by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
return utils.prepare_message(message)
|
||||
|
||||
@@ -29,7 +29,7 @@ def get_conv_by_frontend_id(
|
||||
Get a conversation from the tree root and up to the message with given frontend ID.
|
||||
"""
|
||||
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
messages = pr.fetch_message_conversation(message)
|
||||
return utils.prepare_conversation(messages)
|
||||
@@ -43,7 +43,7 @@ def get_tree_by_frontend_id(
|
||||
Get all messages belonging to the same message tree.
|
||||
Message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
tree = pr.fetch_message_tree(message.message_tree_id)
|
||||
return utils.prepare_tree(tree, message.message_tree_id)
|
||||
@@ -56,7 +56,7 @@ def get_children_by_frontend_id(
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
messages = pr.fetch_message_children(message.id)
|
||||
return utils.prepare_message_list(messages)
|
||||
@@ -70,7 +70,7 @@ def get_descendants_by_frontend_id(
|
||||
Get a subtree which starts with this message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
descendants = pr.fetch_message_descendants(message)
|
||||
return utils.prepare_tree(descendants, message.id)
|
||||
@@ -84,7 +84,7 @@ def get_longest_conv_by_frontend_id(
|
||||
Get the longest conversation from the tree of the message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.message_tree_id)
|
||||
return utils.prepare_conversation(conv)
|
||||
@@ -98,7 +98,7 @@ def get_max_children_by_frontend_id(
|
||||
Get message with the most children from the tree of the provided message.
|
||||
The message is identified by its frontend ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
|
||||
return utils.prepare_tree([message, *children], message.id)
|
||||
|
||||
@@ -29,7 +29,7 @@ def query_frontend_user_messages(
|
||||
"""
|
||||
Query frontend user messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(
|
||||
username=username,
|
||||
api_client_id=api_client_id,
|
||||
@@ -47,6 +47,6 @@ def query_frontend_user_messages(
|
||||
def mark_frontend_user_messages_deleted(
|
||||
username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(username=username, api_client_id=api_client.id)
|
||||
pr.mark_messages_deleted(messages)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats
|
||||
from sqlmodel import Session
|
||||
|
||||
router = APIRouter()
|
||||
@@ -11,15 +12,15 @@ router = APIRouter()
|
||||
def get_assistant_leaderboard(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
return pr.get_user_leaderboard(role="assistant")
|
||||
) -> LeaderboardStats:
|
||||
ur = UserRepository(db, api_client)
|
||||
return ur.get_user_leaderboard(role="assistant")
|
||||
|
||||
|
||||
@router.get("/create/prompter")
|
||||
def get_prompter_leaderboard(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
return pr.get_user_leaderboard(role="prompter")
|
||||
) -> LeaderboardStats:
|
||||
ur = UserRepository(db, api_client)
|
||||
return ur.get_user_leaderboard(role="prompter")
|
||||
|
||||
@@ -29,7 +29,7 @@ def query_messages(
|
||||
"""
|
||||
Query messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(
|
||||
username=username,
|
||||
api_client_id=api_client_id,
|
||||
@@ -51,7 +51,7 @@ def get_message(
|
||||
"""
|
||||
Get a message by its internal ID.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message(message_id)
|
||||
return utils.prepare_message(message)
|
||||
|
||||
@@ -64,7 +64,7 @@ def get_conv(
|
||||
Get a conversation from the tree root and up to the message with given internal ID.
|
||||
"""
|
||||
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.fetch_message_conversation(message_id)
|
||||
return utils.prepare_conversation(messages)
|
||||
|
||||
@@ -76,7 +76,7 @@ def get_tree(
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message(message_id)
|
||||
tree = pr.fetch_message_tree(message.message_tree_id)
|
||||
return utils.prepare_tree(tree, message.message_tree_id)
|
||||
@@ -89,7 +89,7 @@ def get_children(
|
||||
"""
|
||||
Get all messages belonging to the same message tree.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.fetch_message_children(message_id)
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
@@ -101,7 +101,7 @@ def get_descendants(
|
||||
"""
|
||||
Get a subtree which starts with this message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message(message_id)
|
||||
descendants = pr.fetch_message_descendants(message)
|
||||
return utils.prepare_tree(descendants, message.id)
|
||||
@@ -114,7 +114,7 @@ def get_longest_conv(
|
||||
"""
|
||||
Get the longest conversation from the tree of the message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message(message_id)
|
||||
conv = pr.fetch_longest_conversation(message.message_tree_id)
|
||||
return utils.prepare_conversation(conv)
|
||||
@@ -127,7 +127,7 @@ def get_max_children(
|
||||
"""
|
||||
Get message with the most children from the tree of the provided message.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message(message_id)
|
||||
message, children = pr.fetch_message_with_max_children(message.message_tree_id)
|
||||
return utils.prepare_tree([message, *children], message.id)
|
||||
@@ -137,5 +137,5 @@ def get_max_children(
|
||||
def mark_message_deleted(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr.mark_messages_deleted(message_id)
|
||||
|
||||
@@ -13,5 +13,5 @@ def get_message_stats(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
return pr.get_stats()
|
||||
|
||||
@@ -7,7 +7,7 @@ from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1.utils import prepare_conversation
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.prompt_repository import PromptRepository, TaskRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
@@ -190,9 +190,9 @@ def request_task(
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, request.user)
|
||||
pr = PromptRepository(db, api_client, client_user=request.user)
|
||||
task, message_tree_id, parent_message_id = generate_task(request, pr)
|
||||
pr.store_task(task, message_tree_id, parent_message_id, request.collective)
|
||||
pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective)
|
||||
|
||||
except OasstError:
|
||||
raise
|
||||
@@ -217,11 +217,11 @@ def tasks_acknowledge(
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
|
||||
# here we store the message id in the database for the task
|
||||
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
|
||||
pr.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)
|
||||
pr.task_repository.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)
|
||||
|
||||
except OasstError:
|
||||
raise
|
||||
@@ -245,8 +245,8 @@ def tasks_acknowledge_failure(
|
||||
try:
|
||||
logger.info(f"Frontend reports failure to implement task {task_id=}, {nack_request=}.")
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr.acknowledge_task_failure(task_id)
|
||||
pr = PromptRepository(db, api_client)
|
||||
pr.task_repository.acknowledge_task_failure(task_id)
|
||||
except (KeyError, RuntimeError):
|
||||
logger.exception("Failed to not acknowledge task.")
|
||||
raise OasstError("Failed to not acknowledge task.", OasstErrorCode.TASK_NACK_FAILED)
|
||||
@@ -265,7 +265,7 @@ def tasks_interaction(
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, user=interaction.user)
|
||||
pr = PromptRepository(db, api_client, client_user=interaction.user)
|
||||
|
||||
match type(interaction):
|
||||
case protocol_schema.TextReplyToMessage:
|
||||
@@ -323,6 +323,6 @@ def close_collective_task(
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
):
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr.close_task(close_task_request.message_id)
|
||||
tr = TaskRepository(db, api_client)
|
||||
tr.close_task(close_task_request.message_id)
|
||||
return protocol_schema.TaskDone()
|
||||
|
||||
@@ -25,7 +25,7 @@ def label_text(
|
||||
|
||||
try:
|
||||
logger.info(f"Labeling text {text_labels=}.")
|
||||
pr = PromptRepository(db, api_client, user=text_labels.user)
|
||||
pr = PromptRepository(db, api_client, client_user=text_labels.user)
|
||||
pr.store_text_labels(text_labels)
|
||||
|
||||
except Exception:
|
||||
|
||||
@@ -29,7 +29,7 @@ def query_user_messages(
|
||||
"""
|
||||
Query user messages.
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(
|
||||
user_id=user_id,
|
||||
api_client_id=api_client_id,
|
||||
@@ -48,6 +48,6 @@ def query_user_messages(
|
||||
def mark_user_messages_deleted(
|
||||
user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(user_id=user_id)
|
||||
pr.mark_messages_deleted(messages)
|
||||
|
||||
@@ -6,27 +6,56 @@ import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
# The types of States a message tree can have.
|
||||
|
||||
class States(str, Enum):
|
||||
"""States of the Open-Assistant message tree state machine."""
|
||||
|
||||
INITIAL_PROMPT_REVIEW = "initial_prompt_review"
|
||||
"""In this state the message tree consists only of a single inital prompt root node.
|
||||
Initial prompt labeling tasks will determine if the tree goes into `breeding_phase` or
|
||||
`aborted_low_grade`."""
|
||||
|
||||
class States(Enum):
|
||||
INITIAL = "initial"
|
||||
BREEDING_PHASE = "breeding_phase"
|
||||
"""Assistant & prompter human demonstrations are collected. Concurrently labeling tasks
|
||||
are handed out to check if the quality of the replies surpasses the minimum acceptable
|
||||
quality.
|
||||
When the required number of messages passing the initial labelling-quality check has been
|
||||
collected the tree will enter `ranking_phase`. If too many poor-quality labelling responses
|
||||
are received the tree can also enter the `aborted_low_grade` state."""
|
||||
|
||||
RANKING_PHASE = "ranking_phase"
|
||||
"""The tree has been successfully populated with the desired number of messages. Ranking
|
||||
tasks are now handed out for all nodes with more than one child."""
|
||||
|
||||
READY_FOR_SCORING = "ready_for_scoring"
|
||||
CHILDREN_SCORED = "children_scored"
|
||||
FINAL = "final"
|
||||
"""Required ranking responses have been collected and the scoring algorithm can now
|
||||
compute the aggergated ranking scores that will appear in the dataset."""
|
||||
|
||||
READY_FOR_EXPORT = "ready_for_export"
|
||||
"""The Scoring algorithm computed rankings scores for all childern. The message tree can be
|
||||
exported as part of an Open-Assistant message tree dataset."""
|
||||
|
||||
SCORING_FAILED = "scoring_failed"
|
||||
"""An exception occured in the scoring algorithm."""
|
||||
|
||||
ABORTED_LOW_GRADE = "aborted_low_grade"
|
||||
"""The system received too many bad reviews and stopped handing out tasks for this message tree."""
|
||||
|
||||
HALTED_BY_MODERATOR = "halted_by_moderator"
|
||||
"""A moderator decided to manually halt the message tree construction process."""
|
||||
|
||||
|
||||
VALID_STATES = (
|
||||
States.INITIAL,
|
||||
States.INITIAL_PROMPT_REVIEW,
|
||||
States.BREEDING_PHASE,
|
||||
States.RANKING_PHASE,
|
||||
States.READY_FOR_SCORING,
|
||||
States.CHILDREN_SCORED,
|
||||
States.FINAL,
|
||||
States.READY_FOR_EXPORT,
|
||||
States.ABORTED_LOW_GRADE,
|
||||
)
|
||||
|
||||
TERMINAL_STATES = (States.READY_FOR_EXPORT, States.ABORTED_LOW_GRADE, States.SCORING_FAILED, States.HALTED_BY_MODERATOR)
|
||||
|
||||
|
||||
class MessageTreeState(SQLModel, table=True):
|
||||
__tablename__ = "message_tree_state"
|
||||
|
||||
@@ -35,4 +35,4 @@ class Task(SQLModel, table=True):
|
||||
|
||||
@property
|
||||
def expired(self) -> bool:
|
||||
return self.expiry_date is not None and datetime.utcnow() < self.expiry_date
|
||||
return self.expiry_date is not None and datetime.utcnow() > self.expiry_date
|
||||
|
||||
@@ -8,98 +8,39 @@ from uuid import UUID, uuid4
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
from loguru import logger
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User
|
||||
from oasst_backend.models import ApiClient, Message, MessageReaction, TextLabels, User
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats, SystemStats
|
||||
from oasst_shared.schemas.protocol import SystemStats
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import Session, func
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class PromptRepository:
|
||||
def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]):
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
api_client: ApiClient,
|
||||
client_user: Optional[protocol_schema.User] = None,
|
||||
user_repository: Optional[UserRepository] = None,
|
||||
task_repository: Optional[TaskRepository] = None,
|
||||
):
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
self.user = self.lookup_user(user)
|
||||
self.user_repository = user_repository or UserRepository(db, api_client)
|
||||
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
|
||||
self.user_id = self.user.id if self.user else None
|
||||
self.task_repository = task_repository or TaskRepository(
|
||||
db, api_client, client_user, user_repository=self.user_repository
|
||||
)
|
||||
self.journal = JournalWriter(db, api_client, self.user)
|
||||
|
||||
def lookup_user(self, client_user: protocol_schema.User) -> Optional[User]:
|
||||
if not client_user:
|
||||
return None
|
||||
user: User = (
|
||||
self.db.query(User)
|
||||
.filter(
|
||||
User.api_client_id == self.api_client.id,
|
||||
User.username == client_user.id,
|
||||
User.auth_method == client_user.auth_method,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if user is None:
|
||||
# user is unknown, create new record
|
||||
user = User(
|
||||
username=client_user.id,
|
||||
display_name=client_user.display_name,
|
||||
api_client_id=self.api_client.id,
|
||||
auth_method=client_user.auth_method,
|
||||
)
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
elif client_user.display_name and client_user.display_name != user.display_name:
|
||||
# we found the user but the display name changed
|
||||
user.display_name = client_user.display_name
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
return user
|
||||
|
||||
def validate_frontend_message_id(self, message_id: str) -> None:
|
||||
# TODO: Should it be replaced with fastapi/pydantic validation?
|
||||
if not isinstance(message_id, str):
|
||||
raise OasstError(
|
||||
f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID
|
||||
)
|
||||
if not message_id:
|
||||
raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
|
||||
|
||||
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if task.done or task.ack is not None:
|
||||
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
|
||||
|
||||
task.frontend_message_id = frontend_message_id
|
||||
task.ack = True
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
def acknowledge_task_failure(self, task_id):
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if task.done or task.ack is not None:
|
||||
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
|
||||
|
||||
task.ack = False
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
message: Message = (
|
||||
self.db.query(Message)
|
||||
.filter(Message.api_client_id == self.api_client.id, Message.frontend_message_id == frontend_message_id)
|
||||
@@ -113,20 +54,48 @@ class PromptRepository:
|
||||
)
|
||||
return message
|
||||
|
||||
def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
|
||||
self.validate_frontend_message_id(message_id)
|
||||
task = (
|
||||
self.db.query(Task)
|
||||
.filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id)
|
||||
.one_or_none()
|
||||
def insert_message(
|
||||
self,
|
||||
*,
|
||||
message_id: UUID,
|
||||
frontend_message_id: str,
|
||||
parent_id: UUID,
|
||||
message_tree_id: UUID,
|
||||
task_id: UUID,
|
||||
role: str,
|
||||
payload: db_payload.MessagePayload,
|
||||
payload_type: str = None,
|
||||
depth: int = 0,
|
||||
) -> Message:
|
||||
if payload_type is None:
|
||||
if payload is None:
|
||||
payload_type = "null"
|
||||
else:
|
||||
payload_type = type(payload).__name__
|
||||
|
||||
message = Message(
|
||||
id=message_id,
|
||||
parent_id=parent_id,
|
||||
message_tree_id=message_tree_id,
|
||||
task_id=task_id,
|
||||
user_id=self.user_id,
|
||||
role=role,
|
||||
frontend_message_id=frontend_message_id,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=payload_type,
|
||||
payload=PayloadContainer(payload=payload),
|
||||
depth=depth,
|
||||
)
|
||||
return task
|
||||
self.db.add(message)
|
||||
self.db.commit()
|
||||
self.db.refresh(message)
|
||||
return message
|
||||
|
||||
def store_text_reply(self, text: str, frontend_message_id: str, user_frontend_message_id: str) -> Message:
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
self.validate_frontend_message_id(user_frontend_message_id)
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
validate_frontend_message_id(user_frontend_message_id)
|
||||
|
||||
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
|
||||
task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id)
|
||||
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
|
||||
@@ -174,7 +143,7 @@ class PromptRepository:
|
||||
def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction:
|
||||
message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True)
|
||||
|
||||
task = self.fetch_task_by_frontend_message_id(rating.message_id)
|
||||
task = self.task_repository.fetch_task_by_frontend_message_id(rating.message_id)
|
||||
task_payload: db_payload.RateSummaryPayload = task.payload.payload
|
||||
if type(task_payload) != db_payload.RateSummaryPayload:
|
||||
raise OasstError(
|
||||
@@ -201,7 +170,7 @@ class PromptRepository:
|
||||
|
||||
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> MessageReaction:
|
||||
# fetch task
|
||||
task = self.fetch_task_by_frontend_message_id(ranking.message_id)
|
||||
task = self.task_repository.fetch_task_by_frontend_message_id(ranking.message_id)
|
||||
if not task.collective:
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
@@ -255,142 +224,6 @@ class PromptRepository:
|
||||
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
|
||||
)
|
||||
|
||||
def store_task(
|
||||
self,
|
||||
task: protocol_schema.Task,
|
||||
message_tree_id: UUID = None,
|
||||
parent_message_id: UUID = None,
|
||||
collective: bool = False,
|
||||
) -> Task:
|
||||
payload: db_payload.TaskPayload
|
||||
match type(task):
|
||||
case protocol_schema.SummarizeStoryTask:
|
||||
payload = db_payload.SummarizationStoryPayload(story=task.story)
|
||||
|
||||
case protocol_schema.RateSummaryTask:
|
||||
payload = db_payload.RateSummaryPayload(
|
||||
full_text=task.full_text, summary=task.summary, scale=task.scale
|
||||
)
|
||||
|
||||
case protocol_schema.InitialPromptTask:
|
||||
payload = db_payload.InitialPromptPayload(hint=task.hint)
|
||||
|
||||
case protocol_schema.PrompterReplyTask:
|
||||
payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
|
||||
|
||||
case protocol_schema.AssistantReplyTask:
|
||||
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
|
||||
|
||||
case protocol_schema.RankInitialPromptsTask:
|
||||
payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts)
|
||||
|
||||
case protocol_schema.RankPrompterRepliesTask:
|
||||
payload = db_payload.RankPrompterRepliesPayload(
|
||||
type=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
case protocol_schema.RankAssistantRepliesTask:
|
||||
payload = db_payload.RankAssistantRepliesPayload(
|
||||
type=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
case protocol_schema.LabelInitialPromptTask:
|
||||
payload = db_payload.LabelInitialPromptPayload(
|
||||
type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
|
||||
)
|
||||
|
||||
case protocol_schema.LabelPrompterReplyTask:
|
||||
payload = db_payload.LabelPrompterReplyPayload(
|
||||
type=task.type,
|
||||
message_id=task.message_id,
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
valid_labels=task.valid_labels,
|
||||
)
|
||||
|
||||
case protocol_schema.LabelAssistantReplyTask:
|
||||
payload = db_payload.LabelAssistantReplyPayload(
|
||||
type=task.type,
|
||||
message_id=task.message_id,
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
valid_labels=task.valid_labels,
|
||||
)
|
||||
|
||||
case _:
|
||||
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
|
||||
|
||||
task = self.insert_task(
|
||||
payload=payload,
|
||||
id=task.id,
|
||||
message_tree_id=message_tree_id,
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
)
|
||||
assert task.id == task.id
|
||||
return task
|
||||
|
||||
def insert_task(
|
||||
self,
|
||||
payload: db_payload.TaskPayload,
|
||||
id: UUID = None,
|
||||
message_tree_id: UUID = None,
|
||||
parent_message_id: UUID = None,
|
||||
collective: bool = False,
|
||||
) -> Task:
|
||||
c = PayloadContainer(payload=payload)
|
||||
task = Task(
|
||||
id=id,
|
||||
user_id=self.user_id,
|
||||
payload_type=type(payload).__name__,
|
||||
payload=c,
|
||||
api_client_id=self.api_client.id,
|
||||
message_tree_id=message_tree_id,
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
)
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
self.db.refresh(task)
|
||||
return task
|
||||
|
||||
def insert_message(
|
||||
self,
|
||||
*,
|
||||
message_id: UUID,
|
||||
frontend_message_id: str,
|
||||
parent_id: UUID,
|
||||
message_tree_id: UUID,
|
||||
task_id: UUID,
|
||||
role: str,
|
||||
payload: db_payload.MessagePayload,
|
||||
payload_type: str = None,
|
||||
depth: int = 0,
|
||||
) -> Message:
|
||||
if payload_type is None:
|
||||
if payload is None:
|
||||
payload_type = "null"
|
||||
else:
|
||||
payload_type = type(payload).__name__
|
||||
|
||||
message = Message(
|
||||
id=message_id,
|
||||
parent_id=parent_id,
|
||||
message_tree_id=message_tree_id,
|
||||
task_id=task_id,
|
||||
user_id=self.user_id,
|
||||
role=role,
|
||||
frontend_message_id=frontend_message_id,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=payload_type,
|
||||
payload=PayloadContainer(payload=payload),
|
||||
depth=depth,
|
||||
)
|
||||
self.db.add(message)
|
||||
self.db.commit()
|
||||
self.db.refresh(message)
|
||||
return message
|
||||
|
||||
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
|
||||
if self.user_id is None:
|
||||
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
|
||||
@@ -515,28 +348,6 @@ class PromptRepository:
|
||||
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
return message
|
||||
|
||||
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
|
||||
"""
|
||||
Mark task as done. No further messages will be accepted for this task.
|
||||
"""
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
|
||||
|
||||
if not task:
|
||||
raise OasstError(
|
||||
f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND
|
||||
)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
|
||||
if not allow_personal_tasks and not task.collective:
|
||||
raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
|
||||
if task.done:
|
||||
raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)
|
||||
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
@staticmethod
|
||||
def trace_conversation(messages: list[Message] | dict[UUID, Message], last_message: Message) -> list[Message]:
|
||||
"""
|
||||
@@ -728,24 +539,3 @@ class PromptRepository:
|
||||
deleted=result.get(True, 0),
|
||||
message_trees=result.get(None, 0),
|
||||
)
|
||||
|
||||
def get_user_leaderboard(self, role: str) -> LeaderboardStats:
|
||||
"""
|
||||
Get leaderboard stats for Messages created,
|
||||
separate leaderboard for prompts & assistants
|
||||
|
||||
"""
|
||||
query = (
|
||||
self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id))
|
||||
.join(User, User.id == Message.user_id, isouter=True)
|
||||
.filter(Message.deleted is not True, Message.role == role)
|
||||
.group_by(Message.user_id, User.username, User.display_name)
|
||||
.order_by(func.count(Message.user_id).desc())
|
||||
)
|
||||
|
||||
result = [
|
||||
{"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]}
|
||||
for i, j in enumerate(query.all(), start=1)
|
||||
]
|
||||
|
||||
return LeaderboardStats(leaderboard=result)
|
||||
|
||||
@@ -0,0 +1,199 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import oasst_backend.models.db_payload as db_payload
|
||||
from oasst_backend.models import ApiClient, Task
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
def validate_frontend_message_id(message_id: str) -> None:
|
||||
# TODO: Should it be replaced with fastapi/pydantic validation?
|
||||
if not isinstance(message_id, str):
|
||||
raise OasstError(
|
||||
f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID
|
||||
)
|
||||
if not message_id:
|
||||
raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
|
||||
|
||||
|
||||
class TaskRepository:
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
api_client: ApiClient,
|
||||
client_user: Optional[protocol_schema.User],
|
||||
user_repository: UserRepository,
|
||||
):
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
self.user_repository = user_repository
|
||||
self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
|
||||
self.user_id = self.user.id if self.user else None
|
||||
|
||||
def store_task(
|
||||
self,
|
||||
task: protocol_schema.Task,
|
||||
message_tree_id: UUID = None,
|
||||
parent_message_id: UUID = None,
|
||||
collective: bool = False,
|
||||
) -> Task:
|
||||
payload: db_payload.TaskPayload
|
||||
match type(task):
|
||||
case protocol_schema.SummarizeStoryTask:
|
||||
payload = db_payload.SummarizationStoryPayload(story=task.story)
|
||||
|
||||
case protocol_schema.RateSummaryTask:
|
||||
payload = db_payload.RateSummaryPayload(
|
||||
full_text=task.full_text, summary=task.summary, scale=task.scale
|
||||
)
|
||||
|
||||
case protocol_schema.InitialPromptTask:
|
||||
payload = db_payload.InitialPromptPayload(hint=task.hint)
|
||||
|
||||
case protocol_schema.PrompterReplyTask:
|
||||
payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
|
||||
|
||||
case protocol_schema.AssistantReplyTask:
|
||||
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
|
||||
|
||||
case protocol_schema.RankInitialPromptsTask:
|
||||
payload = db_payload.RankInitialPromptsPayload(type=task.type, prompts=task.prompts)
|
||||
|
||||
case protocol_schema.RankPrompterRepliesTask:
|
||||
payload = db_payload.RankPrompterRepliesPayload(
|
||||
type=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
case protocol_schema.RankAssistantRepliesTask:
|
||||
payload = db_payload.RankAssistantRepliesPayload(
|
||||
type=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
case protocol_schema.LabelInitialPromptTask:
|
||||
payload = db_payload.LabelInitialPromptPayload(
|
||||
type=task.type, message_id=task.message_id, prompt=task.prompt, valid_labels=task.valid_labels
|
||||
)
|
||||
|
||||
case protocol_schema.LabelPrompterReplyTask:
|
||||
payload = db_payload.LabelPrompterReplyPayload(
|
||||
type=task.type,
|
||||
message_id=task.message_id,
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
valid_labels=task.valid_labels,
|
||||
)
|
||||
|
||||
case protocol_schema.LabelAssistantReplyTask:
|
||||
payload = db_payload.LabelAssistantReplyPayload(
|
||||
type=task.type,
|
||||
message_id=task.message_id,
|
||||
conversation=task.conversation,
|
||||
reply=task.reply,
|
||||
valid_labels=task.valid_labels,
|
||||
)
|
||||
|
||||
case _:
|
||||
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
|
||||
|
||||
task = self.insert_task(
|
||||
payload=payload,
|
||||
id=task.id,
|
||||
message_tree_id=message_tree_id,
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
)
|
||||
assert task.id == task.id
|
||||
return task
|
||||
|
||||
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if task.done or task.ack is not None:
|
||||
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
|
||||
|
||||
task.frontend_message_id = frontend_message_id
|
||||
task.ack = True
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
|
||||
"""
|
||||
Mark task as done. No further messages will be accepted for this task.
|
||||
"""
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id)
|
||||
|
||||
if not task:
|
||||
raise OasstError(
|
||||
f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND
|
||||
)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
|
||||
if not allow_personal_tasks and not task.collective:
|
||||
raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
|
||||
if task.done:
|
||||
raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)
|
||||
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
def acknowledge_task_failure(self, task_id):
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if task.done or task.ack is not None:
|
||||
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
|
||||
|
||||
task.ack = False
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
def insert_task(
|
||||
self,
|
||||
payload: db_payload.TaskPayload,
|
||||
id: UUID = None,
|
||||
message_tree_id: UUID = None,
|
||||
parent_message_id: UUID = None,
|
||||
collective: bool = False,
|
||||
) -> Task:
|
||||
c = PayloadContainer(payload=payload)
|
||||
task = Task(
|
||||
id=id,
|
||||
user_id=self.user_id,
|
||||
payload_type=type(payload).__name__,
|
||||
payload=c,
|
||||
api_client_id=self.api_client.id,
|
||||
message_tree_id=message_tree_id,
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
)
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
self.db.refresh(task)
|
||||
return task
|
||||
|
||||
def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
|
||||
validate_frontend_message_id(message_id)
|
||||
task = (
|
||||
self.db.query(Task)
|
||||
.filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id)
|
||||
.one_or_none()
|
||||
)
|
||||
return task
|
||||
@@ -0,0 +1,64 @@
|
||||
from typing import Optional
|
||||
|
||||
from oasst_backend.models import ApiClient, Message, User
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats
|
||||
from sqlmodel import Session, func
|
||||
|
||||
|
||||
class UserRepository:
|
||||
def __init__(self, db: Session, api_client: ApiClient):
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
|
||||
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
|
||||
if not client_user:
|
||||
return None
|
||||
user: User = (
|
||||
self.db.query(User)
|
||||
.filter(
|
||||
User.api_client_id == self.api_client.id,
|
||||
User.username == client_user.id,
|
||||
User.auth_method == client_user.auth_method,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if user is None:
|
||||
if create_missing:
|
||||
# user is unknown, create new record
|
||||
user = User(
|
||||
username=client_user.id,
|
||||
display_name=client_user.display_name,
|
||||
api_client_id=self.api_client.id,
|
||||
auth_method=client_user.auth_method,
|
||||
)
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
elif client_user.display_name and client_user.display_name != user.display_name:
|
||||
# we found the user but the display name changed
|
||||
user.display_name = client_user.display_name
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
return user
|
||||
|
||||
def get_user_leaderboard(self, role: str) -> LeaderboardStats:
|
||||
"""
|
||||
Get leaderboard stats for Messages created,
|
||||
separate leaderboard for prompts & assistants
|
||||
|
||||
"""
|
||||
query = (
|
||||
self.db.query(Message.user_id, User.username, User.display_name, func.count(Message.user_id))
|
||||
.join(User, User.id == Message.user_id, isouter=True)
|
||||
.filter(Message.deleted is not True, Message.role == role)
|
||||
.group_by(Message.user_id, User.username, User.display_name)
|
||||
.order_by(func.count(Message.user_id).desc())
|
||||
)
|
||||
|
||||
result = [
|
||||
{"ranking": i, "user_id": j[0], "username": j[1], "display_name": j[2], "score": j[3]}
|
||||
for i, j in enumerate(query.all(), start=1)
|
||||
]
|
||||
|
||||
return LeaderboardStats(leaderboard=result)
|
||||
@@ -1,7 +1,63 @@
|
||||
# General
|
||||
# Research
|
||||
|
||||
This page lists research papers that are relevant to the project.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- Reinforcement Learning from Human Feedback
|
||||
- Generating Text From Language Models
|
||||
- Automatically Generating Instruction Data for Training
|
||||
- Uncertainty Estimation of Language Model Outputs
|
||||
|
||||
## Reinforcement Learning from Human Feedback <a name="reinforcement-learning-from-human-feedback"></a>
|
||||
|
||||
Reinforcement Learning from Human Feedback (RLHF) is a method for fine-tuning a
|
||||
generative language models based on a reward model that is learned from human
|
||||
preference data. This method facilitates the learning of instruction-tuned
|
||||
models, among other things.
|
||||
|
||||
### Learning to summarize from human feedback [[ArXiv](https://arxiv.org/pdf/2009.01325.pdf)], [[Github](https://github.com/openai/summarize-from-feedback)]
|
||||
|
||||
> In this work, we show that it is possible to significantly improve summary
|
||||
> quality by training a model to optimize for human preferences. We collect a
|
||||
> large, high-quality dataset of human comparisons between summaries, train a
|
||||
> model to predict the human-preferred summary, and use that model as a reward
|
||||
> function to fine-tune a summarization policy using reinforcement learning.
|
||||
|
||||
### Training language models to follow instructions with human feedback [[ArXiv](https://arxiv.org/pdf/2203.02155.pdf)]
|
||||
|
||||
> Starting with a set of labeler-written prompts and prompts submitted through
|
||||
> the OpenAI API, we collect a dataset of labeler demonstrations of the desired
|
||||
> model behavior, which we use to fine-tune GPT-3 using supervised learning. We
|
||||
> then collect a dataset of rankings of model outputs, which we use to further
|
||||
> fine-tune this supervised model using reinforcement learning from human
|
||||
> feedback.
|
||||
|
||||
### Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback [[ArXiv](https://arxiv.org/pdf/2204.05862.pdf)]
|
||||
|
||||
> We apply preference modeling and reinforcement learning from human feedback
|
||||
> (RLHF) to finetune language models to act as helpful and harmless assistants.
|
||||
> We find this alignment training improves performance on almost all NLP
|
||||
> evaluations, and is fully compatible with training for specialized skills such
|
||||
> as python coding and summarization.
|
||||
|
||||
## Generating Text From Language Models
|
||||
|
||||
A language model generates output text token by token, autoregressively. The
|
||||
large search space of this task requires some method of narrowing down the set
|
||||
of tokens to be considered in each step. This method, in turn, has a big impact
|
||||
on the quality of the resulting text.
|
||||
|
||||
### RANKGEN: Improving Text Generation with Large Ranking Models [[ArXiv](https://arxiv.org/pdf/2205.09726.pdf)], [[Github](https://github.com/martiansideofthemoon/rankgen)]
|
||||
|
||||
> Given an input sequence (or prefix), modern language models often assign high
|
||||
> probabilities to output sequences that are repetitive, incoherent, or
|
||||
> irrelevant to the prefix; as such, model-generated text also contains such
|
||||
> artifacts. To address these issues we present RankGen, a 1.2B parameter
|
||||
> encoder model for English that scores model generations given a prefix.
|
||||
> RankGen can be flexibly incorporated as a scoring function in beam search and
|
||||
> used to decode from any pretrained language model.
|
||||
|
||||
## Automatically Generating Instruction Data for Training
|
||||
|
||||
This line of work is about significantly reducing the need for manually
|
||||
@@ -32,3 +88,15 @@ models.
|
||||
> rivals the effectiveness of training on open-source manually-curated datasets,
|
||||
> surpassing the performance of models such as T0++ and Tk-Instruct across
|
||||
> various benchmarks.
|
||||
|
||||
## Uncertainty Estimation of Language Model Outputs
|
||||
|
||||
### Teaching models to express their uncertainty in words [[Arxiv](https://arxiv.org/pdf/2205.14334.pdf)]
|
||||
|
||||
> We show that a GPT-3 model can learn to express uncertainty about its own
|
||||
> answers in natural language -- without use of model logits. When given a
|
||||
> question, the model generates both an answer and a level of confidence (e.g.
|
||||
> "90% confidence" or "high confidence"). These levels map to probabilities that
|
||||
> are well calibrated. The model also remains moderately calibrated under
|
||||
> distribution shift, and is sensitive to uncertainty in its own answers, rather
|
||||
> than imitating human examples.
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -35,7 +35,7 @@ class RankGenCollator:
|
||||
max_length: Optional[int] = None
|
||||
max_examples: Optional[int] = None
|
||||
|
||||
def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]:
|
||||
def __call__(self, batch: List[Dict[str, str]]) -> Dict[str, torch.Tensor]:
|
||||
prefixes = []
|
||||
better_answers = []
|
||||
worse_answers = []
|
||||
@@ -193,3 +193,47 @@ class HFSummary(Dataset):
|
||||
valid_idx = np.random.choice(len(rows), self.max_comparison_per_sample)
|
||||
# optimize the format later
|
||||
return context + self.postfix_prompt, [r for idx, r in enumerate(rows) if idx in valid_idx]
|
||||
|
||||
|
||||
class HFDataset(Dataset):
|
||||
"""
|
||||
This is a base huggingface dataset which written to support the
|
||||
simplest pos-neg pair format
|
||||
|
||||
we should do something like this for supervised datasets
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dataset_name, question_field, pos_answer_field, neg_answer_field, subset=None, split=None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
dataset = load_dataset(dataset_name, subset)
|
||||
if split is not None:
|
||||
dataset = dataset[split]
|
||||
|
||||
self.questions = {}
|
||||
self.index2question = {}
|
||||
for row in dataset:
|
||||
question = row[question_field].strip()
|
||||
pos = row[pos_answer_field]
|
||||
neg = row[neg_answer_field]
|
||||
if question not in self.index2question:
|
||||
self.index2question[len(self.index2question)] = question
|
||||
|
||||
if question not in self.questions:
|
||||
self.questions[question] = []
|
||||
self.questions[question].append((pos.strip(), neg.strip()))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.index2question)
|
||||
|
||||
def __getitem__(self, index):
|
||||
question = self.index2question[index]
|
||||
rows = self.questions[question]
|
||||
# optimize the format later
|
||||
return question, rows
|
||||
|
||||
|
||||
class GPTJSynthetic(HFDataset):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Dahoas/synthetic-instruct-gptj-pairwise", "prompt", "chosen", "rejected", None, "train")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from experimental_dataset import DataCollatorForSummaryScore, HFSummaryQuality
|
||||
from rank_datasets import DataCollatorForPairRank, HFSummary, WebGPT
|
||||
from rank_datasets import DataCollatorForPairRank, GPTJSynthetic, HFSummary, WebGPT
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
@@ -25,7 +25,7 @@ def test_webgpt():
|
||||
print(batch["input_ids"].shape)
|
||||
|
||||
|
||||
def test_hf_quality():
|
||||
def test_hf_summary_quality():
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
|
||||
collate_fn = DataCollatorForSummaryScore(tokenizer, max_length=200)
|
||||
@@ -35,6 +35,12 @@ def test_hf_quality():
|
||||
print(batch["input_ids"].shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_hf_quality()
|
||||
# test_webgpt()
|
||||
def test_gptj_dataset():
|
||||
dataset = GPTJSynthetic()
|
||||
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
|
||||
collate_fn = DataCollatorForPairRank(tokenizer, max_length=1024)
|
||||
|
||||
print(len(dataset))
|
||||
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=32)
|
||||
for batch in dataloader:
|
||||
batch["input_ids"].shape
|
||||
|
||||
@@ -1,29 +1,23 @@
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from models import RankGenModel
|
||||
from rank_datasets import DataCollatorForPairRank, HFSummary, RankGenCollator, WebGPT
|
||||
from rank_datasets import DataCollatorForPairRank, RankGenCollator
|
||||
from torch import nn
|
||||
from torch.utils.data import ConcatDataset, Dataset
|
||||
from transformers import (
|
||||
AdamW,
|
||||
AutoModelForSequenceClassification,
|
||||
DataCollator,
|
||||
EvalPrediction,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
get_cosine_schedule_with_warmup,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from utils import argument_parsing, freeze_top_n_layers, get_tokenizer, train_val_dataset
|
||||
from utils import argument_parsing, freeze_top_n_layers, get_datasets, get_tokenizer
|
||||
|
||||
os.environ["WANDB_PROJECT"] = "reward-model"
|
||||
|
||||
@@ -32,11 +26,6 @@ parser = ArgumentParser()
|
||||
parser.add_argument("config", type=str)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomTrainingArguments(TrainingArguments):
|
||||
loss_function: str = "rank"
|
||||
|
||||
|
||||
def compute_metrics(eval_pred):
|
||||
predictions, _ = eval_pred
|
||||
predictions = np.argmax(predictions, axis=1)
|
||||
@@ -60,31 +49,12 @@ class RankTrainer(Trainer):
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
model_name: str = None,
|
||||
args: Optional[TrainingArguments] = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
model_init: Callable[[], PreTrainedModel] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||
callbacks: Optional[List[TrainerCallback]] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
|
||||
loss_function: str = "rank",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
model,
|
||||
args,
|
||||
data_collator,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
tokenizer,
|
||||
model_init,
|
||||
compute_metrics,
|
||||
callbacks,
|
||||
optimizers,
|
||||
preprocess_logits_for_metrics,
|
||||
)
|
||||
self.loss_fct = RankLoss() if args.loss_function == "rank" else nn.CrossEntropyLoss()
|
||||
self.loss_function = args.loss_function
|
||||
super().__init__(model, args, **kwargs)
|
||||
self.loss_fct = RankLoss() if loss_function == "rank" else nn.CrossEntropyLoss()
|
||||
self.loss_function = loss_function
|
||||
self.model_name = model_name
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
@@ -163,11 +133,10 @@ if __name__ == "__main__":
|
||||
params = sum([np.prod(p.size()) for p in model_parameters])
|
||||
print("Number of trainable : {}M".format(int(params / 1e6)))
|
||||
|
||||
args = CustomTrainingArguments(
|
||||
args = TrainingArguments(
|
||||
output_dir=f"{model_name}-finetuned",
|
||||
num_train_epochs=training_conf["num_train_epochs"],
|
||||
warmup_steps=500,
|
||||
loss_function=training_conf["loss"],
|
||||
learning_rate=training_conf["learning_rate"],
|
||||
# half_precision_backend="apex",
|
||||
fp16=training_conf["fp16"],
|
||||
@@ -184,22 +153,9 @@ if __name__ == "__main__":
|
||||
save_steps=1000,
|
||||
report_to="wandb",
|
||||
)
|
||||
train_datasets, evals = [], {}
|
||||
if "webgpt" in training_conf["datasets"]:
|
||||
web_dataset = WebGPT()
|
||||
train, eval = train_val_dataset(web_dataset)
|
||||
train_datasets.append(train)
|
||||
evals["webgpt"] = eval
|
||||
if "hfsummary" in training_conf["datasets"]:
|
||||
sum_train = HFSummary(split="train")
|
||||
train_datasets.append(sum_train)
|
||||
sum_eval = HFSummary(split="valid1")
|
||||
assert len(sum_eval) > 0
|
||||
evals["hfsummary"] = sum_eval
|
||||
train = ConcatDataset(train_datasets)
|
||||
|
||||
tokenizer = get_tokenizer(training_conf["tokenizer_name"])
|
||||
|
||||
train, evals = get_datasets(training_conf["datasets"])
|
||||
if "rankgen" in model_name:
|
||||
collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"])
|
||||
else:
|
||||
@@ -224,8 +180,9 @@ if __name__ == "__main__":
|
||||
model=model,
|
||||
model_name=model_name,
|
||||
args=args,
|
||||
loss_function=training_conf["loss"],
|
||||
train_dataset=train,
|
||||
eval_dataset=eval,
|
||||
eval_dataset=evals,
|
||||
data_collator=collate_fn,
|
||||
tokenizer=tokenizer,
|
||||
compute_metrics=compute_metrics,
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import re
|
||||
from typing import AnyStr, List
|
||||
|
||||
import yaml
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Subset
|
||||
from transformers import AutoTokenizer, T5Tokenizer
|
||||
|
||||
re_reference_remove = re.compile(r"\[([0-9])+\]|\[([0-9])+,([0-9])+\]")
|
||||
# @agoryuno contributed this
|
||||
re_reference_remove = re.compile(r"\[\d+(?:,\s*\d+)*?\]")
|
||||
|
||||
|
||||
def webgpt_return_format(row):
|
||||
@@ -97,6 +99,32 @@ def argument_parsing(parser):
|
||||
return params
|
||||
|
||||
|
||||
def get_datasets(dataset_list: List[AnyStr]):
|
||||
from rank_datasets import GPTJSynthetic, HFSummary, WebGPT
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
train_datasets, evals = [], {}
|
||||
for dataset_name in dataset_list:
|
||||
if "webgpt" == dataset_name:
|
||||
web_dataset = WebGPT()
|
||||
train, eval = train_val_dataset(web_dataset, 0.2)
|
||||
train_datasets.append(train)
|
||||
evals["webgpt"] = eval
|
||||
elif "hfsummary" == dataset_name:
|
||||
sum_train = HFSummary(split="train")
|
||||
train_datasets.append(sum_train)
|
||||
sum_eval = HFSummary(split="valid1")
|
||||
assert len(sum_eval) > 0
|
||||
evals["hfsummary"] = sum_eval
|
||||
elif "gptsynthetic" == dataset_name:
|
||||
dataset = GPTJSynthetic()
|
||||
train, eval = train_val_dataset(dataset, 0.1)
|
||||
train_datasets.append(train)
|
||||
evals["gptsynthetic"] = eval
|
||||
train = ConcatDataset(train_datasets)
|
||||
return train, evals
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
|
||||
@@ -272,36 +272,29 @@ class TextLabel(str, enum.Enum):
|
||||
obj.help_text = help_text
|
||||
return obj
|
||||
|
||||
spam = "spam"
|
||||
spam = "spam", "Seems to be intentionally low-quality or irrelevant"
|
||||
fails_task = "fails_task", "Fails to follow the correct instruction / task"
|
||||
not_appropriate = "not_appropriate", "Inappropriate for customer assistant"
|
||||
violence = "violence", "Encourages or fails to discourage violence/abuse/terrorism/self-harm"
|
||||
harmful = (
|
||||
"harmful",
|
||||
"Harmful content",
|
||||
"The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the question about encouraging violence/abuse/terrorism/self-harm.",
|
||||
excessive_harm = (
|
||||
"excessive_harm",
|
||||
"Content likely to cause excessive harm not justifiable in the context",
|
||||
"Harm refers to physical or mental damage or injury to someone or something. Excessive refers to a reasonable threshold of harm in the context, for instance damaging skin is not excessive in the context of surgery.",
|
||||
)
|
||||
sexual_content = "sexual_content", "Contains sexual content"
|
||||
toxicity = "toxicity"
|
||||
toxicity = "toxicity", "Contains rude, abusive, profane or insulting content"
|
||||
moral_judgement = "moral_judgement", "Expresses moral judgement"
|
||||
political_content = "political_content"
|
||||
humor = "humor"
|
||||
sarcasm = "sarcasm"
|
||||
hate_speech = "hate_speech"
|
||||
profanity = "profanity"
|
||||
ad_hominem = "ad_hominem"
|
||||
insult = "insult"
|
||||
threat = "threat"
|
||||
aggressive = "aggressive"
|
||||
misleading = "misleading"
|
||||
helpful = "helpful"
|
||||
formal = "formal"
|
||||
cringe = "cringe"
|
||||
creative = "creative"
|
||||
beautiful = "beautiful"
|
||||
informative = "informative"
|
||||
based = "based"
|
||||
slang = "slang"
|
||||
political_content = "political_content", "Expresses political views"
|
||||
humor = "humor", "Contains humorous content including sarcasm"
|
||||
hate_speech = (
|
||||
"hate_speech",
|
||||
"Content is abusive or threatening and expresses prejudice against a protected characteristic",
|
||||
"Prejudice refers to preconceived views not based on reason. Protected characteristics include gender, ethnicity, religion, sexual orientation, and similar characteristics.",
|
||||
)
|
||||
threat = "threat", "Contains a threat against a person or persons"
|
||||
misleading = "misleading", "Contains text which is incorrect or misleading"
|
||||
helpful = "helpful", "Completes the task to a high standard"
|
||||
creative = "creative", "Expresses creativity in responding to the task"
|
||||
|
||||
|
||||
class TextLabels(Interaction):
|
||||
|
||||
@@ -8,7 +8,8 @@
|
||||
"rules": {
|
||||
"unused-imports/no-unused-imports": "warn",
|
||||
"simple-import-sort/imports": "warn",
|
||||
"simple-import-sort/exports": "warn"
|
||||
"simple-import-sort/exports": "warn",
|
||||
"eqeqeq": "warn"
|
||||
},
|
||||
"plugins": ["simple-import-sort", "unused-imports"]
|
||||
}
|
||||
|
||||
@@ -39,5 +39,7 @@ next-env.d.ts
|
||||
*.swp
|
||||
|
||||
# cypress
|
||||
/cypress/screenshots
|
||||
/cypress/videos
|
||||
/cypress-visual-screenshots/diff
|
||||
/cypress-visual-screenshots/comparison
|
||||
|
||||
@@ -153,6 +153,34 @@ When writing code for the website, we have a few best practices:
|
||||
1. Define functional React components (with types for all properties when
|
||||
feasible).
|
||||
|
||||
### Developing New Features
|
||||
|
||||
When working on new features or making significant changes that can't be done
|
||||
within a single Pull Request, we ask that you make use of Feature Flags.
|
||||
|
||||
We've set up
|
||||
[`react-feature-flags`](https://www.npmjs.com/package/react-feature-flags) to
|
||||
make this easier. To get started:
|
||||
|
||||
1. Add a new flag entry to `website/src/flags.ts`. We have an example flag you
|
||||
can copy as an example. Be sure to `isActive` to true when testing your
|
||||
features but false when submitting your PR.
|
||||
1. Use your flag wherever you add a new UI element. This can be done with:
|
||||
|
||||
```js
|
||||
import { Flags } from "react-feature-flags";
|
||||
...
|
||||
<Flags authorizedFlags={["yourFlagName"]}>
|
||||
<YourNewComponent />
|
||||
</Flags>
|
||||
```
|
||||
|
||||
You can see an example of how this works by checking `website/src/components/Header/Headers.tsx` where we use `flagTest`.
|
||||
|
||||
1. Once you've finished building out the feature and it is ready for everyone
|
||||
to use, it's safe to remove the `Flag` wrappers around your component and
|
||||
the entry in `flags.ts`.
|
||||
|
||||
### URL Paths
|
||||
|
||||
To use stable and consistent URL paths, we recommend the following strategy for
|
||||
|
||||
Generated
+375
-10
@@ -29,6 +29,7 @@
|
||||
"eslint-config-next": "13.0.6",
|
||||
"eslint-plugin-simple-import-sort": "^8.0.0",
|
||||
"focus-visible": "^5.2.0",
|
||||
"formik": "^2.2.9",
|
||||
"framer-motion": "^6.5.1",
|
||||
"install": "^0.13.0",
|
||||
"next": "13.0.6",
|
||||
@@ -38,6 +39,7 @@
|
||||
"postcss-focus-visible": "^7.1.0",
|
||||
"react": "18.2.0",
|
||||
"react-dom": "18.2.0",
|
||||
"react-feature-flags": "^1.0.0",
|
||||
"react-icons": "^4.7.1",
|
||||
"swr": "^2.0.0",
|
||||
"tailwindcss": "^3.2.4",
|
||||
@@ -56,7 +58,7 @@
|
||||
"@storybook/manager-webpack5": "^6.5.15",
|
||||
"@storybook/react": "^6.5.15",
|
||||
"@storybook/testing-library": "^0.0.13",
|
||||
"@types/node": "18.11.17",
|
||||
"@types/node": "^18.11.17",
|
||||
"@types/react": "18.0.26",
|
||||
"@typescript-eslint/eslint-plugin": "^5.47.1",
|
||||
"babel-loader": "^8.3.0",
|
||||
@@ -66,7 +68,8 @@
|
||||
"eslint-plugin-unused-imports": "^2.0.0",
|
||||
"prettier": "2.8.1",
|
||||
"prisma": "^4.7.1",
|
||||
"typescript": "4.9.4"
|
||||
"ts-node": "^10.9.1",
|
||||
"typescript": "^4.9.4"
|
||||
}
|
||||
},
|
||||
"node_modules/@ampproject/remapping": {
|
||||
@@ -3122,6 +3125,28 @@
|
||||
"node": ">=0.1.90"
|
||||
}
|
||||
},
|
||||
"node_modules/@cspotcode/source-map-support": {
|
||||
"version": "0.8.1",
|
||||
"resolved": "https://registry.npmjs.org/@cspotcode/source-map-support/-/source-map-support-0.8.1.tgz",
|
||||
"integrity": "sha512-IchNf6dN4tHoMFIn/7OE8LWZ19Y6q/67Bmf6vnGREv8RSbBVb9LPJxEcnwrcwX6ixSvaiGoomAUvu4YSxXrVgw==",
|
||||
"devOptional": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/trace-mapping": "0.3.9"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
}
|
||||
},
|
||||
"node_modules/@cspotcode/source-map-support/node_modules/@jridgewell/trace-mapping": {
|
||||
"version": "0.3.9",
|
||||
"resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.9.tgz",
|
||||
"integrity": "sha512-3Belt6tdc8bPgAtbcmdtNJlirVoTmEb5e2gC94PnkwEW9jI6CAHUeoG85tjWP5WquqfavoMtMwiG4P926ZKKuQ==",
|
||||
"devOptional": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/resolve-uri": "^3.0.3",
|
||||
"@jridgewell/sourcemap-codec": "^1.4.10"
|
||||
}
|
||||
},
|
||||
"node_modules/@cypress/request": {
|
||||
"version": "2.88.10",
|
||||
"resolved": "https://registry.npmjs.org/@cypress/request/-/request-2.88.10.tgz",
|
||||
@@ -10836,6 +10861,30 @@
|
||||
"@testing-library/dom": ">=7.21.4"
|
||||
}
|
||||
},
|
||||
"node_modules/@tsconfig/node10": {
|
||||
"version": "1.0.9",
|
||||
"resolved": "https://registry.npmjs.org/@tsconfig/node10/-/node10-1.0.9.tgz",
|
||||
"integrity": "sha512-jNsYVVxU8v5g43Erja32laIDHXeoNvFEpX33OK4d6hljo3jDhCBDhx5dhCCTMWUojscpAagGiRkBKxpdl9fxqA==",
|
||||
"devOptional": true
|
||||
},
|
||||
"node_modules/@tsconfig/node12": {
|
||||
"version": "1.0.11",
|
||||
"resolved": "https://registry.npmjs.org/@tsconfig/node12/-/node12-1.0.11.tgz",
|
||||
"integrity": "sha512-cqefuRsh12pWyGsIoBKJA9luFu3mRxCA+ORZvA4ktLSzIuCUtWVxGIuXigEwO5/ywWFMZ2QEGKWvkZG1zDMTag==",
|
||||
"devOptional": true
|
||||
},
|
||||
"node_modules/@tsconfig/node14": {
|
||||
"version": "1.0.3",
|
||||
"resolved": "https://registry.npmjs.org/@tsconfig/node14/-/node14-1.0.3.tgz",
|
||||
"integrity": "sha512-ysT8mhdixWK6Hw3i1V2AeRqZ5WfXg1G43mqoYlM2nc6388Fq5jcXyr5mRsqViLx/GJYdoL0bfXD8nmF+Zn/Iow==",
|
||||
"devOptional": true
|
||||
},
|
||||
"node_modules/@tsconfig/node16": {
|
||||
"version": "1.0.3",
|
||||
"resolved": "https://registry.npmjs.org/@tsconfig/node16/-/node16-1.0.3.tgz",
|
||||
"integrity": "sha512-yOlFc+7UtL/89t2ZhjPvvB/DeAr3r+Dq58IgzsFkOAvVC6NMJXmCGjbptdXdR9qsX7pKcTL+s87FtYREi2dEEQ==",
|
||||
"devOptional": true
|
||||
},
|
||||
"node_modules/@types/aria-query": {
|
||||
"version": "5.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@types/aria-query/-/aria-query-5.0.1.tgz",
|
||||
@@ -10975,7 +11024,7 @@
|
||||
"version": "18.11.17",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-18.11.17.tgz",
|
||||
"integrity": "sha512-HJSUJmni4BeDHhfzn6nF0sVmd1SMezP7/4F0Lq+aXzmp2xm9O7WXrUtHW/CHlYVtZUbByEvWidHqRtcJXGF2Ng==",
|
||||
"dev": true
|
||||
"devOptional": true
|
||||
},
|
||||
"node_modules/@types/node-fetch": {
|
||||
"version": "2.6.2",
|
||||
@@ -12013,7 +12062,7 @@
|
||||
"version": "4.1.3",
|
||||
"resolved": "https://registry.npmjs.org/arg/-/arg-4.1.3.tgz",
|
||||
"integrity": "sha512-58S9QDqG0Xx27YwPSt9fJxivjYl432YCwfDMfZ+71RAqUrZef7LrKQZ3LHLOwCS4FLNBplP533Zx895SeOCHvA==",
|
||||
"dev": true
|
||||
"devOptional": true
|
||||
},
|
||||
"node_modules/argparse": {
|
||||
"version": "1.0.10",
|
||||
@@ -14691,6 +14740,12 @@
|
||||
"sha.js": "^2.4.8"
|
||||
}
|
||||
},
|
||||
"node_modules/create-require": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/create-require/-/create-require-1.1.1.tgz",
|
||||
"integrity": "sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==",
|
||||
"devOptional": true
|
||||
},
|
||||
"node_modules/cross-spawn": {
|
||||
"version": "7.0.3",
|
||||
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz",
|
||||
@@ -15462,6 +15517,15 @@
|
||||
"resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz",
|
||||
"integrity": "sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw=="
|
||||
},
|
||||
"node_modules/diff": {
|
||||
"version": "4.0.2",
|
||||
"resolved": "https://registry.npmjs.org/diff/-/diff-4.0.2.tgz",
|
||||
"integrity": "sha512-58lmxKSA4BNyLz+HHMUzlOEpg09FV+ev6ZMe3vJihgdxzgcwZ8VoEEPmALCZG9LmqfVoNMMKpttIYTVG6uDY7A==",
|
||||
"devOptional": true,
|
||||
"engines": {
|
||||
"node": ">=0.3.1"
|
||||
}
|
||||
},
|
||||
"node_modules/diffie-hellman": {
|
||||
"version": "5.0.3",
|
||||
"resolved": "https://registry.npmjs.org/diffie-hellman/-/diffie-hellman-5.0.3.tgz",
|
||||
@@ -17688,6 +17752,47 @@
|
||||
"node": ">= 6"
|
||||
}
|
||||
},
|
||||
"node_modules/formik": {
|
||||
"version": "2.2.9",
|
||||
"resolved": "https://registry.npmjs.org/formik/-/formik-2.2.9.tgz",
|
||||
"integrity": "sha512-LQLcISMmf1r5at4/gyJigGn0gOwFbeEAlji+N9InZF6LIMXnFNkO42sCI8Jt84YZggpD4cPWObAZaxpEFtSzNA==",
|
||||
"funding": [
|
||||
{
|
||||
"type": "individual",
|
||||
"url": "https://opencollective.com/formik"
|
||||
}
|
||||
],
|
||||
"dependencies": {
|
||||
"deepmerge": "^2.1.1",
|
||||
"hoist-non-react-statics": "^3.3.0",
|
||||
"lodash": "^4.17.21",
|
||||
"lodash-es": "^4.17.21",
|
||||
"react-fast-compare": "^2.0.1",
|
||||
"tiny-warning": "^1.0.2",
|
||||
"tslib": "^1.10.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": ">=16.8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/formik/node_modules/deepmerge": {
|
||||
"version": "2.2.1",
|
||||
"resolved": "https://registry.npmjs.org/deepmerge/-/deepmerge-2.2.1.tgz",
|
||||
"integrity": "sha512-R9hc1Xa/NOBi9WRVUWg19rl1UB7Tt4kuPd+thNJgFZoxXsTz7ncaPaeIm+40oSGuP33DfMb4sZt1QIGiJzC4EA==",
|
||||
"engines": {
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/formik/node_modules/react-fast-compare": {
|
||||
"version": "2.0.4",
|
||||
"resolved": "https://registry.npmjs.org/react-fast-compare/-/react-fast-compare-2.0.4.tgz",
|
||||
"integrity": "sha512-suNP+J1VU1MWFKcyt7RtjiSWUjvidmQSlqu+eHslq+342xCbGTYmC0mEhPCOHxlW0CywylOC1u2DFAT+bv4dBw=="
|
||||
},
|
||||
"node_modules/formik/node_modules/tslib": {
|
||||
"version": "1.14.1",
|
||||
"resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz",
|
||||
"integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg=="
|
||||
},
|
||||
"node_modules/forwarded": {
|
||||
"version": "0.2.0",
|
||||
"resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz",
|
||||
@@ -20409,8 +20514,12 @@
|
||||
"node_modules/lodash": {
|
||||
"version": "4.17.21",
|
||||
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz",
|
||||
"integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==",
|
||||
"dev": true
|
||||
"integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg=="
|
||||
},
|
||||
"node_modules/lodash-es": {
|
||||
"version": "4.17.21",
|
||||
"resolved": "https://registry.npmjs.org/lodash-es/-/lodash-es-4.17.21.tgz",
|
||||
"integrity": "sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw=="
|
||||
},
|
||||
"node_modules/lodash.debounce": {
|
||||
"version": "4.0.8",
|
||||
@@ -20690,6 +20799,12 @@
|
||||
"semver": "bin/semver"
|
||||
}
|
||||
},
|
||||
"node_modules/make-error": {
|
||||
"version": "1.3.6",
|
||||
"resolved": "https://registry.npmjs.org/make-error/-/make-error-1.3.6.tgz",
|
||||
"integrity": "sha512-s8UhlNe7vPKomQhC1qFelMokr/Sc3AgNbso3n74mVPA5LTZwkB9NlXf4XPamLxJE8h0gh73rM94xvwRT2CVInw==",
|
||||
"devOptional": true
|
||||
},
|
||||
"node_modules/makeerror": {
|
||||
"version": "1.0.12",
|
||||
"resolved": "https://registry.npmjs.org/makeerror/-/makeerror-1.0.12.tgz",
|
||||
@@ -26310,6 +26425,16 @@
|
||||
"resolved": "https://registry.npmjs.org/react-fast-compare/-/react-fast-compare-3.2.0.tgz",
|
||||
"integrity": "sha512-rtGImPZ0YyLrscKI9xTpV8psd6I8VAtjKCzQDlzyDvqJA8XOW78TXYQwNRNd8g8JZnDu8q9Fu/1v4HPAVwVdHA=="
|
||||
},
|
||||
"node_modules/react-feature-flags": {
|
||||
"version": "1.0.0",
|
||||
"resolved": "https://registry.npmjs.org/react-feature-flags/-/react-feature-flags-1.0.0.tgz",
|
||||
"integrity": "sha512-KBFUkXjF7ifGWEQr2Ida4LdAtKGDOwFdTRlXipWxGP9a43vUBxP6IscpYQofGjlzlBcgmFKuzubcVheB6NliEg==",
|
||||
"peerDependencies": {
|
||||
"prop-types": "^15.5.4",
|
||||
"react": ">= 16.3.0",
|
||||
"react-dom": ">= 16.3.0"
|
||||
}
|
||||
},
|
||||
"node_modules/react-focus-lock": {
|
||||
"version": "2.9.2",
|
||||
"resolved": "https://registry.npmjs.org/react-focus-lock/-/react-focus-lock-2.9.2.tgz",
|
||||
@@ -29081,6 +29206,11 @@
|
||||
"resolved": "https://registry.npmjs.org/tiny-invariant/-/tiny-invariant-1.3.1.tgz",
|
||||
"integrity": "sha512-AD5ih2NlSssTCwsMznbvwMZpJ1cbhkGd2uueNxzv2jDlEeZdU04JQfRnggJQ8DrcVBGjAsCKwFBbDlVNtEMlzw=="
|
||||
},
|
||||
"node_modules/tiny-warning": {
|
||||
"version": "1.0.3",
|
||||
"resolved": "https://registry.npmjs.org/tiny-warning/-/tiny-warning-1.0.3.tgz",
|
||||
"integrity": "sha512-lBN9zLN/oAf68o3zNXYrdCt1kP8WsiGW8Oo2ka41b2IM5JL/S1CTyX1rW0mb/zSuJun0ZUrDxx4sqvYS2FWzPA=="
|
||||
},
|
||||
"node_modules/tmp": {
|
||||
"version": "0.2.1",
|
||||
"resolved": "https://registry.npmjs.org/tmp/-/tmp-0.2.1.tgz",
|
||||
@@ -29247,6 +29377,70 @@
|
||||
"node": ">=6.10"
|
||||
}
|
||||
},
|
||||
"node_modules/ts-node": {
|
||||
"version": "10.9.1",
|
||||
"resolved": "https://registry.npmjs.org/ts-node/-/ts-node-10.9.1.tgz",
|
||||
"integrity": "sha512-NtVysVPkxxrwFGUUxGYhfux8k78pQB3JqYBXlLRZgdGUqTO5wU/UyHop5p70iEbGhB7q5KmiZiU0Y3KlJrScEw==",
|
||||
"devOptional": true,
|
||||
"dependencies": {
|
||||
"@cspotcode/source-map-support": "^0.8.0",
|
||||
"@tsconfig/node10": "^1.0.7",
|
||||
"@tsconfig/node12": "^1.0.7",
|
||||
"@tsconfig/node14": "^1.0.0",
|
||||
"@tsconfig/node16": "^1.0.2",
|
||||
"acorn": "^8.4.1",
|
||||
"acorn-walk": "^8.1.1",
|
||||
"arg": "^4.1.0",
|
||||
"create-require": "^1.1.0",
|
||||
"diff": "^4.0.1",
|
||||
"make-error": "^1.1.1",
|
||||
"v8-compile-cache-lib": "^3.0.1",
|
||||
"yn": "3.1.1"
|
||||
},
|
||||
"bin": {
|
||||
"ts-node": "dist/bin.js",
|
||||
"ts-node-cwd": "dist/bin-cwd.js",
|
||||
"ts-node-esm": "dist/bin-esm.js",
|
||||
"ts-node-script": "dist/bin-script.js",
|
||||
"ts-node-transpile-only": "dist/bin-transpile.js",
|
||||
"ts-script": "dist/bin-script-deprecated.js"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@swc/core": ">=1.2.50",
|
||||
"@swc/wasm": ">=1.2.50",
|
||||
"@types/node": "*",
|
||||
"typescript": ">=2.7"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@swc/core": {
|
||||
"optional": true
|
||||
},
|
||||
"@swc/wasm": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/ts-node/node_modules/acorn": {
|
||||
"version": "8.8.1",
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.8.1.tgz",
|
||||
"integrity": "sha512-7zFpHzhnqYKrkYdUjF1HI1bzd0VygEGX8lFk4k5zVMqHEoES+P+7TKI+EvLO9WVMJ8eekdO0aDEK044xTXwPPA==",
|
||||
"devOptional": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=0.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/ts-node/node_modules/acorn-walk": {
|
||||
"version": "8.2.0",
|
||||
"resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.2.0.tgz",
|
||||
"integrity": "sha512-k+iyHEuPgSw6SbuDpGQM+06HQUa04DZ3o+F6CSzXMvvI5KMvnaEqXe+YVe555R9nn6GPt404fos4wcgpw12SDA==",
|
||||
"devOptional": true,
|
||||
"engines": {
|
||||
"node": ">=0.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/ts-pnp": {
|
||||
"version": "1.2.0",
|
||||
"resolved": "https://registry.npmjs.org/ts-pnp/-/ts-pnp-1.2.0.tgz",
|
||||
@@ -29951,6 +30145,12 @@
|
||||
"integrity": "sha512-dsNgbLaTrd6l3MMxTtouOCFw4CBFc/3a+GgYA2YyrJvyQ1u6q4pcu3ktLoUZ/VN/Aw9WsauazbgsgdfVWgAKQg==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/v8-compile-cache-lib": {
|
||||
"version": "3.0.1",
|
||||
"resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",
|
||||
"integrity": "sha512-wa7YjyUGfNZngI/vtK0UHAN+lgDCxBPCylVXGp0zu59Fz5aiGtNXaq3DhIov063MorB+VfufLh3JlF2KdTK3xg==",
|
||||
"devOptional": true
|
||||
},
|
||||
"node_modules/v8-to-istanbul": {
|
||||
"version": "9.0.1",
|
||||
"resolved": "https://registry.npmjs.org/v8-to-istanbul/-/v8-to-istanbul-9.0.1.tgz",
|
||||
@@ -30850,6 +31050,15 @@
|
||||
"fd-slicer": "~1.1.0"
|
||||
}
|
||||
},
|
||||
"node_modules/yn": {
|
||||
"version": "3.1.1",
|
||||
"resolved": "https://registry.npmjs.org/yn/-/yn-3.1.1.tgz",
|
||||
"integrity": "sha512-Ux4ygGWsu2c7isFWe8Yu1YluJmqVhxqK2cLXNQA5AcC3QfbGNpM7fu0Y8b/z16pXLnFxZYvWhd3fhBY9DLmC6Q==",
|
||||
"devOptional": true,
|
||||
"engines": {
|
||||
"node": ">=6"
|
||||
}
|
||||
},
|
||||
"node_modules/yocto-queue": {
|
||||
"version": "0.1.0",
|
||||
"resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz",
|
||||
@@ -33044,6 +33253,27 @@
|
||||
"dev": true,
|
||||
"optional": true
|
||||
},
|
||||
"@cspotcode/source-map-support": {
|
||||
"version": "0.8.1",
|
||||
"resolved": "https://registry.npmjs.org/@cspotcode/source-map-support/-/source-map-support-0.8.1.tgz",
|
||||
"integrity": "sha512-IchNf6dN4tHoMFIn/7OE8LWZ19Y6q/67Bmf6vnGREv8RSbBVb9LPJxEcnwrcwX6ixSvaiGoomAUvu4YSxXrVgw==",
|
||||
"devOptional": true,
|
||||
"requires": {
|
||||
"@jridgewell/trace-mapping": "0.3.9"
|
||||
},
|
||||
"dependencies": {
|
||||
"@jridgewell/trace-mapping": {
|
||||
"version": "0.3.9",
|
||||
"resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.9.tgz",
|
||||
"integrity": "sha512-3Belt6tdc8bPgAtbcmdtNJlirVoTmEb5e2gC94PnkwEW9jI6CAHUeoG85tjWP5WquqfavoMtMwiG4P926ZKKuQ==",
|
||||
"devOptional": true,
|
||||
"requires": {
|
||||
"@jridgewell/resolve-uri": "^3.0.3",
|
||||
"@jridgewell/sourcemap-codec": "^1.4.10"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"@cypress/request": {
|
||||
"version": "2.88.10",
|
||||
"resolved": "https://registry.npmjs.org/@cypress/request/-/request-2.88.10.tgz",
|
||||
@@ -38884,6 +39114,30 @@
|
||||
"@babel/runtime": "^7.12.5"
|
||||
}
|
||||
},
|
||||
"@tsconfig/node10": {
|
||||
"version": "1.0.9",
|
||||
"resolved": "https://registry.npmjs.org/@tsconfig/node10/-/node10-1.0.9.tgz",
|
||||
"integrity": "sha512-jNsYVVxU8v5g43Erja32laIDHXeoNvFEpX33OK4d6hljo3jDhCBDhx5dhCCTMWUojscpAagGiRkBKxpdl9fxqA==",
|
||||
"devOptional": true
|
||||
},
|
||||
"@tsconfig/node12": {
|
||||
"version": "1.0.11",
|
||||
"resolved": "https://registry.npmjs.org/@tsconfig/node12/-/node12-1.0.11.tgz",
|
||||
"integrity": "sha512-cqefuRsh12pWyGsIoBKJA9luFu3mRxCA+ORZvA4ktLSzIuCUtWVxGIuXigEwO5/ywWFMZ2QEGKWvkZG1zDMTag==",
|
||||
"devOptional": true
|
||||
},
|
||||
"@tsconfig/node14": {
|
||||
"version": "1.0.3",
|
||||
"resolved": "https://registry.npmjs.org/@tsconfig/node14/-/node14-1.0.3.tgz",
|
||||
"integrity": "sha512-ysT8mhdixWK6Hw3i1V2AeRqZ5WfXg1G43mqoYlM2nc6388Fq5jcXyr5mRsqViLx/GJYdoL0bfXD8nmF+Zn/Iow==",
|
||||
"devOptional": true
|
||||
},
|
||||
"@tsconfig/node16": {
|
||||
"version": "1.0.3",
|
||||
"resolved": "https://registry.npmjs.org/@tsconfig/node16/-/node16-1.0.3.tgz",
|
||||
"integrity": "sha512-yOlFc+7UtL/89t2ZhjPvvB/DeAr3r+Dq58IgzsFkOAvVC6NMJXmCGjbptdXdR9qsX7pKcTL+s87FtYREi2dEEQ==",
|
||||
"devOptional": true
|
||||
},
|
||||
"@types/aria-query": {
|
||||
"version": "5.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@types/aria-query/-/aria-query-5.0.1.tgz",
|
||||
@@ -39023,7 +39277,7 @@
|
||||
"version": "18.11.17",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-18.11.17.tgz",
|
||||
"integrity": "sha512-HJSUJmni4BeDHhfzn6nF0sVmd1SMezP7/4F0Lq+aXzmp2xm9O7WXrUtHW/CHlYVtZUbByEvWidHqRtcJXGF2Ng==",
|
||||
"dev": true
|
||||
"devOptional": true
|
||||
},
|
||||
"@types/node-fetch": {
|
||||
"version": "2.6.2",
|
||||
@@ -39875,7 +40129,7 @@
|
||||
"version": "4.1.3",
|
||||
"resolved": "https://registry.npmjs.org/arg/-/arg-4.1.3.tgz",
|
||||
"integrity": "sha512-58S9QDqG0Xx27YwPSt9fJxivjYl432YCwfDMfZ+71RAqUrZef7LrKQZ3LHLOwCS4FLNBplP533Zx895SeOCHvA==",
|
||||
"dev": true
|
||||
"devOptional": true
|
||||
},
|
||||
"argparse": {
|
||||
"version": "1.0.10",
|
||||
@@ -41964,6 +42218,12 @@
|
||||
"sha.js": "^2.4.8"
|
||||
}
|
||||
},
|
||||
"create-require": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/create-require/-/create-require-1.1.1.tgz",
|
||||
"integrity": "sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==",
|
||||
"devOptional": true
|
||||
},
|
||||
"cross-spawn": {
|
||||
"version": "7.0.3",
|
||||
"resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz",
|
||||
@@ -42547,6 +42807,12 @@
|
||||
"resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz",
|
||||
"integrity": "sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw=="
|
||||
},
|
||||
"diff": {
|
||||
"version": "4.0.2",
|
||||
"resolved": "https://registry.npmjs.org/diff/-/diff-4.0.2.tgz",
|
||||
"integrity": "sha512-58lmxKSA4BNyLz+HHMUzlOEpg09FV+ev6ZMe3vJihgdxzgcwZ8VoEEPmALCZG9LmqfVoNMMKpttIYTVG6uDY7A==",
|
||||
"devOptional": true
|
||||
},
|
||||
"diffie-hellman": {
|
||||
"version": "5.0.3",
|
||||
"resolved": "https://registry.npmjs.org/diffie-hellman/-/diffie-hellman-5.0.3.tgz",
|
||||
@@ -44293,6 +44559,37 @@
|
||||
"mime-types": "^2.1.12"
|
||||
}
|
||||
},
|
||||
"formik": {
|
||||
"version": "2.2.9",
|
||||
"resolved": "https://registry.npmjs.org/formik/-/formik-2.2.9.tgz",
|
||||
"integrity": "sha512-LQLcISMmf1r5at4/gyJigGn0gOwFbeEAlji+N9InZF6LIMXnFNkO42sCI8Jt84YZggpD4cPWObAZaxpEFtSzNA==",
|
||||
"requires": {
|
||||
"deepmerge": "^2.1.1",
|
||||
"hoist-non-react-statics": "^3.3.0",
|
||||
"lodash": "^4.17.21",
|
||||
"lodash-es": "^4.17.21",
|
||||
"react-fast-compare": "^2.0.1",
|
||||
"tiny-warning": "^1.0.2",
|
||||
"tslib": "^1.10.0"
|
||||
},
|
||||
"dependencies": {
|
||||
"deepmerge": {
|
||||
"version": "2.2.1",
|
||||
"resolved": "https://registry.npmjs.org/deepmerge/-/deepmerge-2.2.1.tgz",
|
||||
"integrity": "sha512-R9hc1Xa/NOBi9WRVUWg19rl1UB7Tt4kuPd+thNJgFZoxXsTz7ncaPaeIm+40oSGuP33DfMb4sZt1QIGiJzC4EA=="
|
||||
},
|
||||
"react-fast-compare": {
|
||||
"version": "2.0.4",
|
||||
"resolved": "https://registry.npmjs.org/react-fast-compare/-/react-fast-compare-2.0.4.tgz",
|
||||
"integrity": "sha512-suNP+J1VU1MWFKcyt7RtjiSWUjvidmQSlqu+eHslq+342xCbGTYmC0mEhPCOHxlW0CywylOC1u2DFAT+bv4dBw=="
|
||||
},
|
||||
"tslib": {
|
||||
"version": "1.14.1",
|
||||
"resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz",
|
||||
"integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg=="
|
||||
}
|
||||
}
|
||||
},
|
||||
"forwarded": {
|
||||
"version": "0.2.0",
|
||||
"resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz",
|
||||
@@ -46307,8 +46604,12 @@
|
||||
"lodash": {
|
||||
"version": "4.17.21",
|
||||
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz",
|
||||
"integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==",
|
||||
"dev": true
|
||||
"integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg=="
|
||||
},
|
||||
"lodash-es": {
|
||||
"version": "4.17.21",
|
||||
"resolved": "https://registry.npmjs.org/lodash-es/-/lodash-es-4.17.21.tgz",
|
||||
"integrity": "sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw=="
|
||||
},
|
||||
"lodash.debounce": {
|
||||
"version": "4.0.8",
|
||||
@@ -46525,6 +46826,12 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"make-error": {
|
||||
"version": "1.3.6",
|
||||
"resolved": "https://registry.npmjs.org/make-error/-/make-error-1.3.6.tgz",
|
||||
"integrity": "sha512-s8UhlNe7vPKomQhC1qFelMokr/Sc3AgNbso3n74mVPA5LTZwkB9NlXf4XPamLxJE8h0gh73rM94xvwRT2CVInw==",
|
||||
"devOptional": true
|
||||
},
|
||||
"makeerror": {
|
||||
"version": "1.0.12",
|
||||
"resolved": "https://registry.npmjs.org/makeerror/-/makeerror-1.0.12.tgz",
|
||||
@@ -50542,6 +50849,12 @@
|
||||
"resolved": "https://registry.npmjs.org/react-fast-compare/-/react-fast-compare-3.2.0.tgz",
|
||||
"integrity": "sha512-rtGImPZ0YyLrscKI9xTpV8psd6I8VAtjKCzQDlzyDvqJA8XOW78TXYQwNRNd8g8JZnDu8q9Fu/1v4HPAVwVdHA=="
|
||||
},
|
||||
"react-feature-flags": {
|
||||
"version": "1.0.0",
|
||||
"resolved": "https://registry.npmjs.org/react-feature-flags/-/react-feature-flags-1.0.0.tgz",
|
||||
"integrity": "sha512-KBFUkXjF7ifGWEQr2Ida4LdAtKGDOwFdTRlXipWxGP9a43vUBxP6IscpYQofGjlzlBcgmFKuzubcVheB6NliEg==",
|
||||
"requires": {}
|
||||
},
|
||||
"react-focus-lock": {
|
||||
"version": "2.9.2",
|
||||
"resolved": "https://registry.npmjs.org/react-focus-lock/-/react-focus-lock-2.9.2.tgz",
|
||||
@@ -52722,6 +53035,11 @@
|
||||
"resolved": "https://registry.npmjs.org/tiny-invariant/-/tiny-invariant-1.3.1.tgz",
|
||||
"integrity": "sha512-AD5ih2NlSssTCwsMznbvwMZpJ1cbhkGd2uueNxzv2jDlEeZdU04JQfRnggJQ8DrcVBGjAsCKwFBbDlVNtEMlzw=="
|
||||
},
|
||||
"tiny-warning": {
|
||||
"version": "1.0.3",
|
||||
"resolved": "https://registry.npmjs.org/tiny-warning/-/tiny-warning-1.0.3.tgz",
|
||||
"integrity": "sha512-lBN9zLN/oAf68o3zNXYrdCt1kP8WsiGW8Oo2ka41b2IM5JL/S1CTyX1rW0mb/zSuJun0ZUrDxx4sqvYS2FWzPA=="
|
||||
},
|
||||
"tmp": {
|
||||
"version": "0.2.1",
|
||||
"resolved": "https://registry.npmjs.org/tmp/-/tmp-0.2.1.tgz",
|
||||
@@ -52852,6 +53170,41 @@
|
||||
"integrity": "sha512-q5W7tVM71e2xjHZTlgfTDoPF/SmqKG5hddq9SzR49CH2hayqRKJtQ4mtRlSxKaJlR/+9rEM+mnBHf7I2/BQcpQ==",
|
||||
"dev": true
|
||||
},
|
||||
"ts-node": {
|
||||
"version": "10.9.1",
|
||||
"resolved": "https://registry.npmjs.org/ts-node/-/ts-node-10.9.1.tgz",
|
||||
"integrity": "sha512-NtVysVPkxxrwFGUUxGYhfux8k78pQB3JqYBXlLRZgdGUqTO5wU/UyHop5p70iEbGhB7q5KmiZiU0Y3KlJrScEw==",
|
||||
"devOptional": true,
|
||||
"requires": {
|
||||
"@cspotcode/source-map-support": "^0.8.0",
|
||||
"@tsconfig/node10": "^1.0.7",
|
||||
"@tsconfig/node12": "^1.0.7",
|
||||
"@tsconfig/node14": "^1.0.0",
|
||||
"@tsconfig/node16": "^1.0.2",
|
||||
"acorn": "^8.4.1",
|
||||
"acorn-walk": "^8.1.1",
|
||||
"arg": "^4.1.0",
|
||||
"create-require": "^1.1.0",
|
||||
"diff": "^4.0.1",
|
||||
"make-error": "^1.1.1",
|
||||
"v8-compile-cache-lib": "^3.0.1",
|
||||
"yn": "3.1.1"
|
||||
},
|
||||
"dependencies": {
|
||||
"acorn": {
|
||||
"version": "8.8.1",
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.8.1.tgz",
|
||||
"integrity": "sha512-7zFpHzhnqYKrkYdUjF1HI1bzd0VygEGX8lFk4k5zVMqHEoES+P+7TKI+EvLO9WVMJ8eekdO0aDEK044xTXwPPA==",
|
||||
"devOptional": true
|
||||
},
|
||||
"acorn-walk": {
|
||||
"version": "8.2.0",
|
||||
"resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.2.0.tgz",
|
||||
"integrity": "sha512-k+iyHEuPgSw6SbuDpGQM+06HQUa04DZ3o+F6CSzXMvvI5KMvnaEqXe+YVe555R9nn6GPt404fos4wcgpw12SDA==",
|
||||
"devOptional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"ts-pnp": {
|
||||
"version": "1.2.0",
|
||||
"resolved": "https://registry.npmjs.org/ts-pnp/-/ts-pnp-1.2.0.tgz",
|
||||
@@ -53362,6 +53715,12 @@
|
||||
"integrity": "sha512-dsNgbLaTrd6l3MMxTtouOCFw4CBFc/3a+GgYA2YyrJvyQ1u6q4pcu3ktLoUZ/VN/Aw9WsauazbgsgdfVWgAKQg==",
|
||||
"dev": true
|
||||
},
|
||||
"v8-compile-cache-lib": {
|
||||
"version": "3.0.1",
|
||||
"resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz",
|
||||
"integrity": "sha512-wa7YjyUGfNZngI/vtK0UHAN+lgDCxBPCylVXGp0zu59Fz5aiGtNXaq3DhIov063MorB+VfufLh3JlF2KdTK3xg==",
|
||||
"devOptional": true
|
||||
},
|
||||
"v8-to-istanbul": {
|
||||
"version": "9.0.1",
|
||||
"resolved": "https://registry.npmjs.org/v8-to-istanbul/-/v8-to-istanbul-9.0.1.tgz",
|
||||
@@ -54083,6 +54442,12 @@
|
||||
"fd-slicer": "~1.1.0"
|
||||
}
|
||||
},
|
||||
"yn": {
|
||||
"version": "3.1.1",
|
||||
"resolved": "https://registry.npmjs.org/yn/-/yn-3.1.1.tgz",
|
||||
"integrity": "sha512-Ux4ygGWsu2c7isFWe8Yu1YluJmqVhxqK2cLXNQA5AcC3QfbGNpM7fu0Y8b/z16pXLnFxZYvWhd3fhBY9DLmC6Q==",
|
||||
"devOptional": true
|
||||
},
|
||||
"yocto-queue": {
|
||||
"version": "0.1.0",
|
||||
"resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz",
|
||||
|
||||
@@ -18,6 +18,9 @@
|
||||
"fix:format": "prettier --write ./src",
|
||||
"fix": "npm run fix:format && npm run fix:lint"
|
||||
},
|
||||
"prisma": {
|
||||
"seed": "ts-node --compiler-options {\"module\":\"CommonJS\"} prisma/seed.ts"
|
||||
},
|
||||
"dependencies": {
|
||||
"@chakra-ui/react": "^2.4.4",
|
||||
"@dnd-kit/core": "^6.0.6",
|
||||
@@ -40,6 +43,7 @@
|
||||
"eslint-config-next": "13.0.6",
|
||||
"eslint-plugin-simple-import-sort": "^8.0.0",
|
||||
"focus-visible": "^5.2.0",
|
||||
"formik": "^2.2.9",
|
||||
"framer-motion": "^6.5.1",
|
||||
"install": "^0.13.0",
|
||||
"next": "13.0.6",
|
||||
@@ -49,6 +53,7 @@
|
||||
"postcss-focus-visible": "^7.1.0",
|
||||
"react": "18.2.0",
|
||||
"react-dom": "18.2.0",
|
||||
"react-feature-flags": "^1.0.0",
|
||||
"react-icons": "^4.7.1",
|
||||
"swr": "^2.0.0",
|
||||
"tailwindcss": "^3.2.4",
|
||||
@@ -67,7 +72,7 @@
|
||||
"@storybook/manager-webpack5": "^6.5.15",
|
||||
"@storybook/react": "^6.5.15",
|
||||
"@storybook/testing-library": "^0.0.13",
|
||||
"@types/node": "18.11.17",
|
||||
"@types/node": "^18.11.17",
|
||||
"@types/react": "18.0.26",
|
||||
"@typescript-eslint/eslint-plugin": "^5.47.1",
|
||||
"babel-loader": "^8.3.0",
|
||||
@@ -77,6 +82,7 @@
|
||||
"eslint-plugin-unused-imports": "^2.0.0",
|
||||
"prettier": "2.8.1",
|
||||
"prisma": "^4.7.1",
|
||||
"typescript": "4.9.4"
|
||||
"ts-node": "^10.9.1",
|
||||
"typescript": "^4.9.4"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
/**
|
||||
* A seed function to inject test data into the web database.
|
||||
*
|
||||
* Use by running
|
||||
* npx prisma db seed
|
||||
*/
|
||||
|
||||
import { PrismaClient } from "@prisma/client";
|
||||
const prisma = new PrismaClient();
|
||||
|
||||
async function main() {
|
||||
const users = [
|
||||
{ email: "general.user.a@example.com", name: "A", role: "general" },
|
||||
{ email: "general.user.b@example.com", name: "B", role: "general" },
|
||||
{ email: "general.user.c@example.com", name: "C", role: "general" },
|
||||
{ email: "general.user.d@example.com", name: "D", role: "general" },
|
||||
{ email: "general.user.e@example.com", name: "E", role: "general" },
|
||||
{ email: "general.user.f@example.com", name: "F", role: "general" },
|
||||
{ email: "general.user.g@example.com", name: "G", role: "general" },
|
||||
{ email: "general.user.h@example.com", name: "H", role: "general" },
|
||||
{ email: "general.user.i@example.com", name: "I", role: "general" },
|
||||
{ email: "general.user.j@example.com", name: "J", role: "general" },
|
||||
{ email: "general.user.k@example.com", name: "K", role: "general" },
|
||||
{ email: "general.user.l@example.com", name: "L", role: "general" },
|
||||
{ email: "general.user.m@example.com", name: "M", role: "general" },
|
||||
{ email: "general.user.n@example.com", name: "N", role: "general" },
|
||||
{ email: "general.user.o@example.com", name: "O", role: "general" },
|
||||
{ email: "general.user.p@example.com", name: "P", role: "general" },
|
||||
{ email: "general.user.q@example.com", name: "Q", role: "general" },
|
||||
{ email: "general.user.r@example.com", name: "R", role: "general" },
|
||||
{ email: "malicious.user.1@example.com", name: "M1", role: "general" },
|
||||
{ email: "malicious.user.2@example.com", name: "M2", role: "general" },
|
||||
];
|
||||
await Promise.all(
|
||||
users.map(async ({ email, name, role }) => {
|
||||
await prisma.user.upsert({
|
||||
where: { email },
|
||||
update: { name, role },
|
||||
create: {
|
||||
email,
|
||||
name,
|
||||
role,
|
||||
},
|
||||
});
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
main()
|
||||
.then(async () => {
|
||||
await prisma.$disconnect();
|
||||
})
|
||||
.catch(async (e) => {
|
||||
console.error(e);
|
||||
await prisma.$disconnect();
|
||||
process.exit(1);
|
||||
});
|
||||
@@ -1,9 +1,63 @@
|
||||
import { Button, ButtonProps } from "@chakra-ui/react";
|
||||
import {
|
||||
Button,
|
||||
ButtonProps,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuItem,
|
||||
MenuList,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
Textarea,
|
||||
useDisclosure,
|
||||
} from "@chakra-ui/react";
|
||||
import { useState } from "react";
|
||||
import { FaChevronDown } from "react-icons/fa";
|
||||
|
||||
interface SkipButtonProps extends ButtonProps {
|
||||
onSkip: (reason: string) => void;
|
||||
}
|
||||
|
||||
export const SkipButton = ({ onSkip, ...props }: SkipButtonProps) => {
|
||||
const { isOpen, onOpen: showModal, onClose: closeModal } = useDisclosure();
|
||||
const [value, setValue] = useState("");
|
||||
|
||||
const onSubmit = () => {
|
||||
onSkip(value);
|
||||
setValue("");
|
||||
closeModal();
|
||||
};
|
||||
|
||||
export const SkipButton = ({ children, ...props }: ButtonProps) => {
|
||||
return (
|
||||
<Button size="lg" variant="outline" {...props}>
|
||||
{children}
|
||||
</Button>
|
||||
<>
|
||||
<Button size="lg" variant="outline" onClick={showModal} {...props}>
|
||||
Skip
|
||||
</Button>
|
||||
<Modal isOpen={isOpen} onClose={closeModal}>
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalHeader>Skip</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody>
|
||||
<Textarea
|
||||
value={value}
|
||||
onChange={(e) => setValue(e.target.value)}
|
||||
resize="none"
|
||||
placeholder="Any feedback on this task?"
|
||||
/>
|
||||
</ModalBody>
|
||||
|
||||
<ModalFooter>
|
||||
<Button colorScheme="blue" mr={3} onClick={onSubmit}>
|
||||
Send
|
||||
</Button>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -3,7 +3,7 @@ import Link from "next/link";
|
||||
|
||||
import { TaskCategory, TaskTypes } from "../Tasks/TaskTypes";
|
||||
|
||||
const displayTaskCategories = [TaskCategory.Create, TaskCategory.Evaluate];
|
||||
const displayTaskCategories = [TaskCategory.Create, TaskCategory.Evaluate, TaskCategory.Label];
|
||||
|
||||
export const TaskOption = () => {
|
||||
const backgroundColor = useColorModeValue("white", "gray.700");
|
||||
@@ -12,9 +12,9 @@ export const TaskOption = () => {
|
||||
<Box className="flex flex-col gap-14" fontFamily="inter">
|
||||
{displayTaskCategories.map((category, categoryIndex) => (
|
||||
<div key={categoryIndex}>
|
||||
<Text className="text-2xl font-bold pb-4">{TaskCategory[category]}</Text>
|
||||
<Text className="text-2xl font-bold pb-4">{category}</Text>
|
||||
<SimpleGrid columns={[1, 2, 2, 3, 4]} gap={4}>
|
||||
{TaskTypes.filter((task) => task.category == category).map((item, itemIndex) => (
|
||||
{TaskTypes.filter((task) => task.category === category).map((item, itemIndex) => (
|
||||
<Link key={itemIndex} href={item.pathname}>
|
||||
<GridItem
|
||||
bg={backgroundColor}
|
||||
|
||||
@@ -29,7 +29,7 @@ import useSWRMutation from "swr/mutation";
|
||||
|
||||
export const FlaggableElement = (props) => {
|
||||
const [isEditing, setIsEditing] = useBoolean();
|
||||
const { trigger } = useSWRMutation("/api/v1/text_labels", poster, {
|
||||
const { trigger } = useSWRMutation("/api/set_label", poster, {
|
||||
onSuccess: () => {
|
||||
setIsEditing.off;
|
||||
},
|
||||
@@ -42,7 +42,12 @@ export const FlaggableElement = (props) => {
|
||||
label_map.set(flag.attributeName, sliderValues[i]);
|
||||
}
|
||||
});
|
||||
trigger({ post_id: props.post_id, label_map: Object.fromEntries(label_map), text: props.text });
|
||||
trigger({
|
||||
message_id: props.message_id,
|
||||
post_id: props.post_id,
|
||||
label_map: Object.fromEntries(label_map),
|
||||
text: props.text,
|
||||
});
|
||||
};
|
||||
const [checkboxValues, setCheckboxValues] = useState(new Array(TEXT_LABEL_FLAGS.length).fill(false));
|
||||
const [sliderValues, setSliderValues] = useState(new Array(TEXT_LABEL_FLAGS.length).fill(1));
|
||||
@@ -118,7 +123,8 @@ export const FlaggableElement = (props) => {
|
||||
</Popover>
|
||||
);
|
||||
};
|
||||
function FlagCheckbox(props: {
|
||||
|
||||
export function FlagCheckbox(props: {
|
||||
option: textFlagLabels;
|
||||
idx: number;
|
||||
checkboxValues: boolean[];
|
||||
@@ -183,40 +189,40 @@ interface textFlagLabels {
|
||||
const TEXT_LABEL_FLAGS: textFlagLabels[] = [
|
||||
// For the time being this list is configured on the FE.
|
||||
// In the future it may be provided by the API.
|
||||
// {
|
||||
// attributeName: "fails_task",
|
||||
// labelText: "Fails to follow the correct instruction / task",
|
||||
// additionalExplanation: "__TODO__",
|
||||
// },
|
||||
// {
|
||||
// attributeName: "not_customer_assistant_appropriate",
|
||||
// labelText: "Inappropriate for customer assistant",
|
||||
// additionalExplanation: "__TODO__",
|
||||
// },
|
||||
{
|
||||
attributeName: "fails_task",
|
||||
labelText: "Fails to follow the correct instruction / task",
|
||||
additionalExplanation: "__TODO__",
|
||||
},
|
||||
{
|
||||
attributeName: "not_customer_assistant_appropriate",
|
||||
labelText: "Inappropriate for customer assistant",
|
||||
additionalExplanation: "__TODO__",
|
||||
},
|
||||
{
|
||||
attributeName: "contains_sexual_content",
|
||||
attributeName: "sexual_content",
|
||||
labelText: "Contains sexual content",
|
||||
},
|
||||
{
|
||||
attributeName: "contains_violent_content",
|
||||
attributeName: "violence",
|
||||
labelText: "Contains violent content",
|
||||
},
|
||||
{
|
||||
attributeName: "encourages_violence",
|
||||
labelText: "Encourages or fails to discourage violence/abuse/terrorism/self-harm",
|
||||
},
|
||||
{
|
||||
attributeName: "denigrates_a_protected_class",
|
||||
labelText: "Denigrates a protected class",
|
||||
},
|
||||
{
|
||||
attributeName: "gives_harmful_advice",
|
||||
labelText: "Fails to follow the correct instruction / task",
|
||||
additionalExplanation:
|
||||
"The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the question about encouraging violence/abuse/terrorism/self-harm.",
|
||||
},
|
||||
{
|
||||
attributeName: "expresses_moral_judgement",
|
||||
labelText: "Expresses moral judgement",
|
||||
},
|
||||
// {
|
||||
// attributeName: "encourages_violence",
|
||||
// labelText: "Encourages or fails to discourage violence/abuse/terrorism/self-harm",
|
||||
// },
|
||||
// {
|
||||
// attributeName: "denigrates_a_protected_class",
|
||||
// labelText: "Denigrates a protected class",
|
||||
// },
|
||||
// {
|
||||
// attributeName: "gives_harmful_advice",
|
||||
// labelText: "Fails to follow the correct instruction / task",
|
||||
// additionalExplanation:
|
||||
// "The advice given in the output is harmful or counter-productive. This may be in addition to, but is distinct from the question about encouraging violence/abuse/terrorism/self-harm.",
|
||||
// },
|
||||
// {
|
||||
// attributeName: "expresses_moral_judgement",
|
||||
// labelText: "Expresses moral judgement",
|
||||
// },
|
||||
];
|
||||
|
||||
@@ -2,6 +2,7 @@ import { Box, Button, Text, useColorMode } from "@chakra-ui/react";
|
||||
import Image from "next/image";
|
||||
import Link from "next/link";
|
||||
import { useSession } from "next-auth/react";
|
||||
import { Flags } from "react-feature-flags";
|
||||
import { FaUser } from "react-icons/fa";
|
||||
|
||||
import { UserMenu } from "./UserMenu";
|
||||
@@ -42,6 +43,9 @@ export function Header(props) {
|
||||
</Link>
|
||||
</div>
|
||||
<div className="flex items-center gap-4">
|
||||
<Flags authorizedFlags={["flagTest"]}>
|
||||
<div>FlagTest</div>
|
||||
</Flags>
|
||||
<AccountButton />
|
||||
<UserMenu />
|
||||
</div>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// https://nextjs.org/docs/basic-features/layouts
|
||||
|
||||
import type { NextPage } from "next";
|
||||
import { FiLayout, FiMessageSquare } from "react-icons/fi";
|
||||
import { FiLayout, FiMessageSquare, FiUsers } from "react-icons/fi";
|
||||
import { Header } from "src/components/Header";
|
||||
|
||||
import { Footer } from "./Footer";
|
||||
@@ -51,4 +51,22 @@ export const getDashboardLayout = (page: React.ReactElement) => (
|
||||
</div>
|
||||
);
|
||||
|
||||
export const getAdminLayout = (page: React.ReactElement) => (
|
||||
<div className="grid grid-rows-[min-content_1fr_min-content] h-full justify-items-stretch">
|
||||
<Header transparent={true} />
|
||||
<SideMenuLayout
|
||||
menuButtonOptions={[
|
||||
{
|
||||
label: "Users",
|
||||
pathname: "/admin",
|
||||
desc: "Users Dashboard",
|
||||
icon: FiUsers,
|
||||
},
|
||||
]}
|
||||
>
|
||||
{page}
|
||||
</SideMenuLayout>
|
||||
</div>
|
||||
);
|
||||
|
||||
export const noLayout = (page: React.ReactElement) => page;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Progress } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
|
||||
export const LoadingScreen = ({ text }) => {
|
||||
export const LoadingScreen = ({ text = "Loading..." } = {}) => {
|
||||
const { colorMode } = useColorMode();
|
||||
const mainClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
|
||||
@@ -1,36 +1,38 @@
|
||||
import { Grid } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { useMemo } from "react";
|
||||
|
||||
import { FlaggableElement } from "./FlaggableElement";
|
||||
|
||||
export interface Message {
|
||||
text: string;
|
||||
is_assistant: boolean;
|
||||
message_id: string;
|
||||
}
|
||||
|
||||
const getBgColor = (isAssistant: boolean, colorMode: "light" | "dark") => {
|
||||
if (colorMode === "light") {
|
||||
return isAssistant ? "bg-slate-800" : "bg-sky-900";
|
||||
} else {
|
||||
return isAssistant ? "bg-black" : "bg-sky-900";
|
||||
}
|
||||
};
|
||||
|
||||
export const Messages = ({ messages, post_id }: { messages: Message[]; post_id: string }) => {
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const items = messages.map(({ text, is_assistant }: Message, i: number) => {
|
||||
const items = messages.map((messageProps: Message, i: number) => {
|
||||
const { message_id, text } = messageProps;
|
||||
return (
|
||||
<FlaggableElement text={text} post_id={post_id} key={i + text}>
|
||||
<div
|
||||
key={i + text}
|
||||
className={`${getBgColor(is_assistant, colorMode)} p-4 rounded-md text-white whitespace-pre-wrap`}
|
||||
>
|
||||
{text}
|
||||
</div>
|
||||
<FlaggableElement text={text} post_id={post_id} message_id={message_id} key={i + text}>
|
||||
<MessageView {...messageProps} />
|
||||
</FlaggableElement>
|
||||
);
|
||||
});
|
||||
// Maybe also show a legend of the colors?
|
||||
return <Grid gap={2}>{items}</Grid>;
|
||||
};
|
||||
|
||||
export const MessageView = ({ is_assistant, text, message_id }: Message) => {
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const bgColor = useMemo(() => {
|
||||
if (colorMode === "light") {
|
||||
return is_assistant ? "bg-slate-800" : "bg-sky-900";
|
||||
} else {
|
||||
return is_assistant ? "bg-black" : "bg-sky-900";
|
||||
}
|
||||
}, [colorMode, is_assistant]);
|
||||
|
||||
return <div className={`${bgColor} p-4 rounded-md text-white whitespace-pre-wrap`}>{text}</div>;
|
||||
};
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { Flex } from "@chakra-ui/react";
|
||||
import clsx from "clsx";
|
||||
import { SkipButton } from "src/components/Buttons/Skip";
|
||||
import { SubmitButton } from "src/components/Buttons/Submit";
|
||||
import { TaskInfo } from "src/components/TaskInfo/TaskInfo";
|
||||
@@ -10,31 +11,38 @@ export interface TaskControlsProps {
|
||||
tasks: any[];
|
||||
className?: string;
|
||||
onSubmitResponse: (task: { id: string }) => void;
|
||||
onSkip: () => void;
|
||||
onSkipTask: (task: { id: string }, reason: string) => void;
|
||||
onNextTask: () => void;
|
||||
}
|
||||
|
||||
export const TaskControls = (props: TaskControlsProps) => {
|
||||
const extraClases = props.className || "";
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const baseClasses = "flex flex-row justify-items-stretch mb-8 p-4 rounded-lg max-w-7xl mx-auto";
|
||||
const taskControlClases =
|
||||
colorMode === "light"
|
||||
? `${baseClasses} bg-white text-gray-800 shadow-lg ${extraClases}`
|
||||
: `${baseClasses} bg-slate-800 text-slate-400 shadow-xl ring-1 ring-white/10 ring-inset ${extraClases}`;
|
||||
|
||||
const isLightMode = colorMode === "light";
|
||||
const endTask = props.tasks[props.tasks.length - 1];
|
||||
return (
|
||||
<section className={taskControlClases}>
|
||||
<section
|
||||
className={clsx(
|
||||
"flex-row justify-items-stretch mb-8 p-4 rounded-lg max-w-7xl mx-auto space-y-4 sm:space-y-0 sm:flex",
|
||||
props.className,
|
||||
{
|
||||
"bg-white text-gray-800 shadow-lg": isLightMode,
|
||||
"bg-slate-800 text-slate-400 shadow-xl ring-1 ring-white/10 ring-inset": !isLightMode,
|
||||
}
|
||||
)}
|
||||
>
|
||||
<TaskInfo id={props.tasks[0].id} output="Submit your answer" />
|
||||
<Flex justify="center" ml="auto" gap={2}>
|
||||
<SkipButton>Skip</SkipButton>
|
||||
<SkipButton
|
||||
onSkip={(reason: string) => {
|
||||
props.onSkipTask(props.tasks[0], reason);
|
||||
}}
|
||||
/>
|
||||
{endTask.task.type !== "task_done" ? (
|
||||
<SubmitButton colorScheme="blue" data-cy="submit" onClick={() => props.onSubmitResponse(props.tasks[0])}>
|
||||
Submit
|
||||
</SubmitButton>
|
||||
) : (
|
||||
<SubmitButton colorScheme="green" data-cy="next-task" onClick={props.onSkip}>
|
||||
<SubmitButton colorScheme="green" data-cy="next-task" onClick={props.onNextTask}>
|
||||
Next Task
|
||||
</SubmitButton>
|
||||
)}
|
||||
|
||||
@@ -28,7 +28,7 @@ export const TrackedTextarea = (props: TrackedTextboxProps) => {
|
||||
|
||||
return (
|
||||
<Stack direction={"column"}>
|
||||
<Textarea data-cy="reply" value={props.text} onChange={props.onTextChange} {...props.textareaProps} onCapture />
|
||||
<Textarea data-cy="reply" value={props.text} onChange={props.onTextChange} {...props.textareaProps} />
|
||||
<Progress size={"md"} rounded={"md"} value={wordCount} colorScheme={progressColor} max={props.thresholds.goal} />
|
||||
</Stack>
|
||||
);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
export const TaskInfo = ({ id, output }: { id: string; output: string }) => {
|
||||
return (
|
||||
<div className="grid grid-cols-[min-content_auto] gap-x-2 ">
|
||||
<div className="grid grid-cols-[min-content_auto] gap-x-2">
|
||||
<b>Prompt</b>
|
||||
<span data-cy="task-id">{id}</span>
|
||||
<b>Output</b>
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
import { Card, CardBody, Flex, Heading } from "@chakra-ui/react";
|
||||
import Image from "next/image";
|
||||
import Link from "next/link";
|
||||
|
||||
export type OptionProps = {
|
||||
img: string;
|
||||
alt: string;
|
||||
title: string;
|
||||
link: string;
|
||||
};
|
||||
|
||||
export const TaskOption = (props: OptionProps) => {
|
||||
const { alt, img, title, link } = props;
|
||||
return (
|
||||
<Link href={link}>
|
||||
<Card
|
||||
maxW="300"
|
||||
minW="300"
|
||||
minH="300"
|
||||
maxH="300"
|
||||
className="transition ease-in-out duration-500 sm:grayscale hover:grayscale-0"
|
||||
>
|
||||
<CardBody width="full" height="full">
|
||||
<Flex direction="column" alignItems="center" justifyContent="center">
|
||||
<Image src={img} alt={alt} width={200} height={200} />
|
||||
<Heading
|
||||
mt={-10}
|
||||
className="bg-gradient-to-r from-indigo-600 via-sky-400 to-indigo-700 bg-clip-text tracking-tight text-transparent"
|
||||
textAlign="center"
|
||||
fontSize="3xl"
|
||||
>
|
||||
{title}
|
||||
</Heading>
|
||||
</Flex>
|
||||
</CardBody>
|
||||
</Card>
|
||||
</Link>
|
||||
);
|
||||
};
|
||||
@@ -1,23 +0,0 @@
|
||||
import { Divider, Flex, Heading } from "@chakra-ui/react";
|
||||
import React from "react";
|
||||
|
||||
export type TaskOptionsProps = {
|
||||
title: string;
|
||||
children: JSX.Element | JSX.Element[];
|
||||
};
|
||||
|
||||
export const TaskOptions = (props: TaskOptionsProps) => {
|
||||
const { title, children } = props;
|
||||
return (
|
||||
<Flex gap={10} wrap="wrap" justifyContent="center">
|
||||
<Heading
|
||||
className="bg-gradient-to-r from-indigo-600 via-sky-400 to-indigo-700 bg-clip-text tracking-tight text-transparent"
|
||||
fontSize="5xl"
|
||||
>
|
||||
{title}
|
||||
</Heading>
|
||||
<Divider mt={-8} />
|
||||
{children}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
@@ -1,73 +0,0 @@
|
||||
import { Flex } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import React from "react";
|
||||
|
||||
import { TaskOption } from "./TaskOption";
|
||||
import { TaskOptions } from "./TaskOptions";
|
||||
|
||||
export const TaskSelection = () => {
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
return (
|
||||
<Flex
|
||||
gap={10}
|
||||
wrap="wrap"
|
||||
justifyContent="space-evenly"
|
||||
width="full"
|
||||
height="full"
|
||||
alignItems={"center"}
|
||||
className={mainBgClasses}
|
||||
>
|
||||
<TaskOptions key="create" title="Create">
|
||||
{/* <TaskOption
|
||||
alt="Summarize Stories"
|
||||
img="/images/logos/logo.svg"
|
||||
title="Summarize stories"
|
||||
link="/create/summarize_story"
|
||||
/> */}
|
||||
<TaskOption
|
||||
alt="Create Initial Prompt"
|
||||
img="/images/logos/logo.svg"
|
||||
title="Create Initial Prompt"
|
||||
link="/create/initial_prompt"
|
||||
/>
|
||||
<TaskOption alt="Reply as User" img="/images/logos/logo.svg" title="Reply as User" link="/create/user_reply" />
|
||||
<TaskOption
|
||||
alt="Reply as Assistant"
|
||||
img="/images/logos/logo.svg"
|
||||
title="Reply as Assistant"
|
||||
link="/create/assistant_reply"
|
||||
/>
|
||||
</TaskOptions>
|
||||
<TaskOptions key="evaluate" title="Evaluate">
|
||||
{/*
|
||||
Commented out while the backend does not support them.
|
||||
<TaskOption
|
||||
alt="Rate Prompts"
|
||||
img="/images/logos/logo.svg"
|
||||
title="Rate Prompts"
|
||||
link="/evaluate/rate_summary"
|
||||
/> */}
|
||||
<TaskOption
|
||||
alt="Rank Initial Prompts"
|
||||
img="/images/logos/logo.svg"
|
||||
title="Rank Initial Prompts"
|
||||
link="/evaluate/rank_initial_prompts"
|
||||
/>
|
||||
<TaskOption
|
||||
alt="Rank User Replies"
|
||||
img="/images/logos/logo.svg"
|
||||
title="Rank User Replies"
|
||||
link="/evaluate/rank_user_replies"
|
||||
/>
|
||||
<TaskOption
|
||||
alt="Rank Assistant Replies"
|
||||
img="/images/logos/logo.svg"
|
||||
title="Rank Assistant Replies"
|
||||
link="/evaluate/rank_assistant_replies"
|
||||
/>
|
||||
</TaskOptions>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
@@ -1,3 +0,0 @@
|
||||
export { TaskOption } from "./TaskOption";
|
||||
export { TaskOptions } from "./TaskOptions";
|
||||
export { TaskSelection } from "./TaskSelection";
|
||||
@@ -1,10 +1,22 @@
|
||||
import { useState } from "react";
|
||||
|
||||
import { Messages } from "src/components/Messages";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { TrackedTextarea } from "src/components/Survey/TrackedTextarea";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { TaskType } from "./TaskTypes";
|
||||
|
||||
export const CreateTask = ({ tasks, taskType, trigger, mutate, mainBgClasses }) => {
|
||||
export interface CreateTaskProps {
|
||||
// we need a task type
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
tasks: any[];
|
||||
taskType: TaskType;
|
||||
trigger: (update: { id: string; update_type: string; content: { text: string } }) => void;
|
||||
onSkipTask: (task: { id: string }, reason: string) => void;
|
||||
onNextTask: () => void;
|
||||
mainBgClasses: string;
|
||||
}
|
||||
export const CreateTask = ({ tasks, taskType, trigger, onSkipTask, onNextTask, mainBgClasses }: CreateTaskProps) => {
|
||||
const task = tasks[0].task;
|
||||
|
||||
const [inputText, setInputText] = useState("");
|
||||
@@ -20,11 +32,6 @@ export const CreateTask = ({ tasks, taskType, trigger, mutate, mainBgClasses })
|
||||
});
|
||||
};
|
||||
|
||||
const fetchNextTask = () => {
|
||||
setInputText("");
|
||||
mutate();
|
||||
};
|
||||
|
||||
const textChangeHandler = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
setInputText(event.target.value);
|
||||
};
|
||||
@@ -48,7 +55,15 @@ export const CreateTask = ({ tasks, taskType, trigger, mutate, mainBgClasses })
|
||||
</>
|
||||
</TwoColumnsWithCards>
|
||||
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSubmitResponse={submitResponse}
|
||||
onSkipTask={(task, reason) => {
|
||||
setInputText("");
|
||||
onSkipTask(task, reason);
|
||||
}}
|
||||
onNextTask={onNextTask}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -5,7 +5,17 @@ import { TaskControlsOverridable } from "src/components/Survey/TaskControlsOverr
|
||||
|
||||
import { MessageTable } from "../Messages/MessageTable";
|
||||
|
||||
export const EvaluateTask = ({ tasks, trigger, mutate, mainBgClasses }) => {
|
||||
export interface EvaluateTaskProps {
|
||||
// we need a task type
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
tasks: any[];
|
||||
trigger: (update: { id: string; update_type: string; content: { ranking: number[] } }) => void;
|
||||
onSkipTask: (task: { id: string }, reason: string) => void;
|
||||
onNextTask: () => void;
|
||||
mainBgClasses: string;
|
||||
}
|
||||
|
||||
export const EvaluateTask = ({ tasks, trigger, onSkipTask, onNextTask, mainBgClasses }: EvaluateTaskProps) => {
|
||||
const [ranking, setRanking] = useState<number[]>([]);
|
||||
const submitResponse = (task) => {
|
||||
trigger({
|
||||
@@ -17,10 +27,6 @@ export const EvaluateTask = ({ tasks, trigger, mutate, mainBgClasses }) => {
|
||||
});
|
||||
};
|
||||
|
||||
const fetchNextTask = () => {
|
||||
setRanking([]);
|
||||
mutate();
|
||||
};
|
||||
let messages = null;
|
||||
if (tasks[0].task.conversation) {
|
||||
messages = tasks[0].task.conversation.messages;
|
||||
@@ -45,7 +51,11 @@ export const EvaluateTask = ({ tasks, trigger, mutate, mainBgClasses }) => {
|
||||
isValid={ranking.length == tasks[0].task[sortables].length}
|
||||
prepareForSubmit={() => setRanking(tasks[0].task[sortables].map((_, idx) => idx))}
|
||||
onSubmitResponse={submitResponse}
|
||||
onSkip={fetchNextTask}
|
||||
onSkipTask={(task, reason) => {
|
||||
setRanking([]);
|
||||
onSkipTask(task, reason);
|
||||
}}
|
||||
onNextTask={onNextTask}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
import { Grid, Slider, SliderFilledTrack, SliderThumb, SliderTrack } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { ReactNode, useEffect, useId, useMemo, useState } from "react";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { colors } from "styles/Theme/colors";
|
||||
|
||||
export const LabelTask = ({
|
||||
title,
|
||||
desc,
|
||||
messages,
|
||||
inputs,
|
||||
controls,
|
||||
}: {
|
||||
title: string;
|
||||
desc: string;
|
||||
messages: ReactNode;
|
||||
inputs: ReactNode;
|
||||
controls: ReactNode;
|
||||
}) => {
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
|
||||
const card = useMemo(
|
||||
() => (
|
||||
<>
|
||||
<h5 className="text-lg font-semibold">{title}</h5>
|
||||
<p className="text-lg py-1">{desc}</p>
|
||||
{messages}
|
||||
</>
|
||||
),
|
||||
[title, desc, messages]
|
||||
);
|
||||
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<TwoColumnsWithCards>
|
||||
{card}
|
||||
{inputs}
|
||||
</TwoColumnsWithCards>
|
||||
{controls}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// TODO: consolidate with FlaggableElement
|
||||
interface LabelSliderGroupProps {
|
||||
labelIDs: Array<string>;
|
||||
onChange: (sliderValues: number[]) => unknown;
|
||||
}
|
||||
|
||||
export const LabelSliderGroup = ({ labelIDs, onChange }: LabelSliderGroupProps) => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>(Array.from({ length: labelIDs.length }).map(() => 0));
|
||||
|
||||
useEffect(() => {
|
||||
onChange(sliderValues);
|
||||
}, [sliderValues, onChange]);
|
||||
|
||||
return (
|
||||
<Grid templateColumns="auto 1fr" rowGap={1} columnGap={3}>
|
||||
{labelIDs.map((labelId, idx) => (
|
||||
<CheckboxSliderItem
|
||||
key={idx}
|
||||
labelId={labelId}
|
||||
sliderValue={sliderValues[idx]}
|
||||
sliderHandler={(sliderValue) => {
|
||||
const newState = sliderValues.slice();
|
||||
newState[idx] = sliderValue;
|
||||
setSliderValues(newState);
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</Grid>
|
||||
);
|
||||
};
|
||||
|
||||
function CheckboxSliderItem(props: {
|
||||
labelId: string;
|
||||
sliderValue: number;
|
||||
sliderHandler: (newVal: number) => unknown;
|
||||
}) {
|
||||
const id = useId();
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const labelTextClass = colorMode === "light" ? `text-${colors.light.text}` : `text-${colors.dark.text}`;
|
||||
|
||||
return (
|
||||
<>
|
||||
<label className="text-sm" htmlFor={id}>
|
||||
{/* TODO: display real text instead of just the id */}
|
||||
<span className={labelTextClass}>{props.labelId}</span>
|
||||
</label>
|
||||
<Slider defaultValue={0} onChangeEnd={(val) => props.sliderHandler(val / 100)}>
|
||||
<SliderTrack>
|
||||
<SliderFilledTrack />
|
||||
<SliderThumb />
|
||||
</SliderTrack>
|
||||
</Slider>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -1,10 +1,25 @@
|
||||
import { CreateTask } from "./CreateTask";
|
||||
import { EvaluateTask } from "./EvaluateTask";
|
||||
import { TaskCategory, TaskTypes } from "./TaskTypes";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
import poster from "src/lib/poster";
|
||||
|
||||
export const Task = ({ tasks, trigger, mutate, mainBgClasses }) => {
|
||||
const task = tasks[0].task;
|
||||
|
||||
const { trigger: sendRejection } = useSWRMutation("/api/reject_task", poster, {
|
||||
onSuccess: async () => {
|
||||
mutate();
|
||||
},
|
||||
});
|
||||
|
||||
const rejectTask = (task: { id: string }, reason: string) => {
|
||||
sendRejection({
|
||||
id: task.id,
|
||||
reason,
|
||||
});
|
||||
};
|
||||
|
||||
function taskTypeComponent(type) {
|
||||
const taskType = TaskTypes.find((taskType) => taskType.type === type);
|
||||
const category = taskType.category;
|
||||
@@ -14,13 +29,22 @@ export const Task = ({ tasks, trigger, mutate, mainBgClasses }) => {
|
||||
<CreateTask
|
||||
tasks={tasks}
|
||||
trigger={trigger}
|
||||
mutate={mutate}
|
||||
onSkipTask={rejectTask}
|
||||
onNextTask={mutate}
|
||||
taskType={taskType}
|
||||
mainBgClasses={mainBgClasses}
|
||||
/>
|
||||
);
|
||||
case TaskCategory.Evaluate:
|
||||
return <EvaluateTask tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />;
|
||||
return (
|
||||
<EvaluateTask
|
||||
tasks={tasks}
|
||||
trigger={trigger}
|
||||
onSkipTask={rejectTask}
|
||||
onNextTask={mutate}
|
||||
mainBgClasses={mainBgClasses}
|
||||
/>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,21 @@
|
||||
export enum TaskCategory {
|
||||
Create,
|
||||
Evaluate,
|
||||
Create = "Create",
|
||||
Evaluate = "Evaluate",
|
||||
Label = "Label",
|
||||
}
|
||||
|
||||
export const TaskTypes = [
|
||||
export interface TaskType {
|
||||
label: string;
|
||||
desc: string;
|
||||
category: TaskCategory;
|
||||
pathname: string;
|
||||
type: string;
|
||||
overview?: string;
|
||||
instruction?: string;
|
||||
}
|
||||
|
||||
export const TaskTypes: TaskType[] = [
|
||||
// create
|
||||
{
|
||||
label: "Create Initial Prompts",
|
||||
desc: "Write initial prompts to help Open Assistant to try replying to diverse messages.",
|
||||
@@ -31,6 +43,7 @@ export const TaskTypes = [
|
||||
overview: "Given the following conversation, provide an adequate reply",
|
||||
instruction: "Provide the assistant`s reply",
|
||||
},
|
||||
// evaluate
|
||||
{
|
||||
label: "Rank User Replies",
|
||||
category: TaskCategory.Evaluate,
|
||||
@@ -52,4 +65,26 @@ export const TaskTypes = [
|
||||
pathname: "/evaluate/rank_initial_prompts",
|
||||
type: "rank_initial_prompts",
|
||||
},
|
||||
// label
|
||||
{
|
||||
label: "Label Initial Prompt",
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_initial_prompt",
|
||||
type: "label_initial_prompt",
|
||||
},
|
||||
{
|
||||
label: "Label Prompter Reply",
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_prompter_reply",
|
||||
type: "label_prompter_reply",
|
||||
},
|
||||
{
|
||||
label: "Label Assistant Reply",
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_assistant_reply",
|
||||
type: "label_assistant_reply",
|
||||
},
|
||||
];
|
||||
|
||||
@@ -1,4 +1,18 @@
|
||||
import { Table, TableCaption, TableContainer, Tbody, Td, Th, Thead, Tr } from "@chakra-ui/react";
|
||||
import {
|
||||
Button,
|
||||
Flex,
|
||||
Spacer,
|
||||
Stack,
|
||||
Table,
|
||||
TableCaption,
|
||||
TableContainer,
|
||||
Tbody,
|
||||
Td,
|
||||
Th,
|
||||
Thead,
|
||||
Tr,
|
||||
} from "@chakra-ui/react";
|
||||
import Link from "next/link";
|
||||
import { useState } from "react";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import useSWR from "swr";
|
||||
@@ -7,37 +21,60 @@ import useSWR from "swr";
|
||||
* Fetches users from the users api route and then presents them in a simple Chakra table.
|
||||
*/
|
||||
const UsersCell = () => {
|
||||
// Fetch and save the users.
|
||||
const [pageIndex, setPageIndex] = useState(0);
|
||||
const [users, setUsers] = useState([]);
|
||||
const { isLoading } = useSWR("/api/admin/users", fetcher, {
|
||||
|
||||
// Fetch and save the users.
|
||||
// This follows useSWR's recommendation for simple pagination:
|
||||
// https://swr.vercel.app/docs/pagination#when-to-use-useswr
|
||||
useSWR(`/api/admin/users?pageIndex=${pageIndex}`, fetcher, {
|
||||
onSuccess: setUsers,
|
||||
});
|
||||
|
||||
const toPreviousPage = () => {
|
||||
setPageIndex(Math.max(0, pageIndex - 1));
|
||||
};
|
||||
|
||||
const toNextPage = () => {
|
||||
setPageIndex(pageIndex + 1);
|
||||
};
|
||||
|
||||
// Present users in a naive table.
|
||||
return (
|
||||
<TableContainer>
|
||||
<Table variant="simple">
|
||||
<TableCaption>Users</TableCaption>
|
||||
<Thead>
|
||||
<Tr>
|
||||
<Th>Id</Th>
|
||||
<Th>Email</Th>
|
||||
<Th>Name</Th>
|
||||
<Th>Role</Th>
|
||||
</Tr>
|
||||
</Thead>
|
||||
<Tbody>
|
||||
{users.map((user, index) => (
|
||||
<Tr key={index}>
|
||||
<Td>{user.id}</Td>
|
||||
<Td>{user.email}</Td>
|
||||
<Td>{user.name}</Td>
|
||||
<Td>{user.role}</Td>
|
||||
<Stack>
|
||||
<Flex p="2">
|
||||
<Button onClick={toPreviousPage}>Previous</Button>
|
||||
<Spacer />
|
||||
<Button onClick={toNextPage}>Next</Button>
|
||||
</Flex>
|
||||
<TableContainer>
|
||||
<Table variant="simple">
|
||||
<TableCaption>Users</TableCaption>
|
||||
<Thead>
|
||||
<Tr>
|
||||
<Th>Id</Th>
|
||||
<Th>Email</Th>
|
||||
<Th>Name</Th>
|
||||
<Th>Role</Th>
|
||||
<Th>Update</Th>
|
||||
</Tr>
|
||||
))}
|
||||
</Tbody>
|
||||
</Table>
|
||||
</TableContainer>
|
||||
</Thead>
|
||||
<Tbody>
|
||||
{users.map((user, index) => (
|
||||
<Tr key={index}>
|
||||
<Td>{user.id}</Td>
|
||||
<Td>{user.email}</Td>
|
||||
<Td>{user.name}</Td>
|
||||
<Td>{user.role}</Td>
|
||||
<Td>
|
||||
<Link href={`/admin/manage_user/${user.id}`}>Manage</Link>
|
||||
</Td>
|
||||
</Tr>
|
||||
))}
|
||||
</Tbody>
|
||||
</Table>
|
||||
</TableContainer>
|
||||
</Stack>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
const flags = [{ name: "flagTest", isActive: false }];
|
||||
|
||||
export default flags;
|
||||
@@ -0,0 +1,9 @@
|
||||
import { useGenericTaskAPI } from "../useGenericTaskAPI";
|
||||
|
||||
interface CreateInitialPromptTask {
|
||||
id: string;
|
||||
type: "initial_prompt";
|
||||
hint: string;
|
||||
}
|
||||
|
||||
export const useCreateInitialPrompt = () => useGenericTaskAPI<CreateInitialPromptTask>("initial_prompt");
|
||||
@@ -0,0 +1,24 @@
|
||||
import { useGenericTaskAPI } from "../useGenericTaskAPI";
|
||||
|
||||
interface BaseCreateReplyTask {
|
||||
id: string;
|
||||
conversation: {
|
||||
messages: Array<{
|
||||
text: string;
|
||||
is_assistant: boolean;
|
||||
message_id: string;
|
||||
}>;
|
||||
};
|
||||
}
|
||||
|
||||
export interface CreateAssistantReplyTask extends BaseCreateReplyTask {
|
||||
type: "assistant_reply";
|
||||
}
|
||||
|
||||
export interface CreatePrompterReplyTask extends BaseCreateReplyTask {
|
||||
type: "prompter_reply";
|
||||
}
|
||||
|
||||
export const useCreateAssistantReply = () => useGenericTaskAPI<CreateAssistantReplyTask>("assistant_reply");
|
||||
|
||||
export const useCreatePrompterReply = () => useGenericTaskAPI<CreatePrompterReplyTask>("prompter_reply");
|
||||
@@ -0,0 +1,9 @@
|
||||
import { useGenericTaskAPI } from "../useGenericTaskAPI";
|
||||
|
||||
interface RankInitialPromptsTask {
|
||||
id: string;
|
||||
type: "rank_initial_prompts";
|
||||
prompts: string[];
|
||||
}
|
||||
|
||||
export const useRankInitialPromptsTask = () => useGenericTaskAPI<RankInitialPromptsTask>("rank_initial_prompts");
|
||||
@@ -0,0 +1,25 @@
|
||||
import { useGenericTaskAPI } from "../useGenericTaskAPI";
|
||||
|
||||
interface BaseRankRepliesTask {
|
||||
id: string;
|
||||
replies: string[];
|
||||
conversation: {
|
||||
messages: Array<{
|
||||
text: string;
|
||||
is_assistant: boolean;
|
||||
message_id: string;
|
||||
}>;
|
||||
};
|
||||
}
|
||||
|
||||
interface RankAssistantRepliesTask extends BaseRankRepliesTask {
|
||||
type: "rank_assistant_replies";
|
||||
}
|
||||
|
||||
interface RankPrompterRepliesTask extends BaseRankRepliesTask {
|
||||
type: "rank_prompter_replies";
|
||||
}
|
||||
|
||||
export const useRankAssistantRepliesTask = () => useGenericTaskAPI<RankAssistantRepliesTask>("rank_assistant_replies");
|
||||
|
||||
export const useRankPrompterRepliesTask = () => useGenericTaskAPI<RankPrompterRepliesTask>("rank_prompter_replies");
|
||||
@@ -0,0 +1,22 @@
|
||||
import { TaskResponse } from "../useGenericTaskAPI";
|
||||
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
|
||||
|
||||
export interface LabelAssistantReplyTask {
|
||||
id: string;
|
||||
type: LabelingTaskType.label_assistant_reply;
|
||||
message_id: string;
|
||||
valid_labels: string[];
|
||||
reply: string;
|
||||
conversation: {
|
||||
messages: Array<{
|
||||
text: string;
|
||||
is_assistant: boolean;
|
||||
message_id: string;
|
||||
}>;
|
||||
};
|
||||
}
|
||||
|
||||
export type LabelAssistantReplyTaskResponse = TaskResponse<LabelAssistantReplyTask>;
|
||||
|
||||
export const useLabelAssistantReplyTask = () =>
|
||||
useLabelingTask<LabelAssistantReplyTask>(LabelingTaskType.label_assistant_reply);
|
||||
@@ -0,0 +1,15 @@
|
||||
import { TaskResponse } from "../useGenericTaskAPI";
|
||||
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
|
||||
|
||||
export interface LabelInitialPromptTask {
|
||||
id: string;
|
||||
type: LabelingTaskType.label_initial_prompt;
|
||||
message_id: string;
|
||||
valid_labels: string[];
|
||||
prompt: string;
|
||||
}
|
||||
|
||||
export type LabelInitialPromptTaskResponse = TaskResponse<LabelInitialPromptTask>;
|
||||
|
||||
export const useLabelInitialPromptTask = () =>
|
||||
useLabelingTask<LabelInitialPromptTask>(LabelingTaskType.label_initial_prompt);
|
||||
@@ -0,0 +1,22 @@
|
||||
import { TaskResponse } from "../useGenericTaskAPI";
|
||||
import { LabelingTaskType, useLabelingTask } from "./useLabelingTask";
|
||||
|
||||
export interface LabelPrompterReplyTask {
|
||||
id: string;
|
||||
type: LabelingTaskType.label_prompter_reply;
|
||||
message_id: string;
|
||||
valid_labels: string[];
|
||||
reply: string;
|
||||
conversation: {
|
||||
messages: Array<{
|
||||
text: string;
|
||||
is_assistant: boolean;
|
||||
message_id: string;
|
||||
}>;
|
||||
};
|
||||
}
|
||||
|
||||
export type LabelPrompterReplyTaskResponse = TaskResponse<LabelPrompterReplyTask>;
|
||||
|
||||
export const useLabelPrompterReplyTask = () =>
|
||||
useLabelingTask<LabelPrompterReplyTask>(LabelingTaskType.label_prompter_reply);
|
||||
@@ -0,0 +1,20 @@
|
||||
import { useGenericTaskAPI } from "../useGenericTaskAPI";
|
||||
|
||||
export const enum LabelingTaskType {
|
||||
label_initial_prompt = "label_initial_prompt",
|
||||
label_prompter_reply = "label_prompter_reply",
|
||||
label_assistant_reply = "label_assistant_reply",
|
||||
}
|
||||
|
||||
export const useLabelingTask = <TaskType>(endpoint: LabelingTaskType) => {
|
||||
const { tasks, isLoading, trigger, reset, error } = useGenericTaskAPI<TaskType>(endpoint);
|
||||
|
||||
const submit = (id: string, message_id: string, text: string, validLabels: string[], labelWeights: number[]) => {
|
||||
console.assert(validLabels.length === labelWeights.length);
|
||||
const labels = Object.fromEntries(validLabels.map((label, i) => [label, labelWeights[i]]));
|
||||
|
||||
return trigger({ id, update_type: "text_labels", content: { labels, text, message_id } });
|
||||
};
|
||||
|
||||
return { tasks, isLoading, submit, reset, error };
|
||||
};
|
||||
@@ -0,0 +1,38 @@
|
||||
import { useState } from "react";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
// TODO: type & centralize types for all tasks
|
||||
|
||||
export interface TaskResponse<TaskType> {
|
||||
id: string;
|
||||
userId: string;
|
||||
task: TaskType;
|
||||
}
|
||||
|
||||
export const useGenericTaskAPI = <TaskType,>(taskApiEndpoint: string) => {
|
||||
type ConcreteTaskResponse = TaskResponse<TaskType>;
|
||||
|
||||
const [tasks, setTasks] = useState<ConcreteTaskResponse[]>([]);
|
||||
|
||||
const { isLoading, mutate, error } = useSWRImmutable<ConcreteTaskResponse>(
|
||||
"/api/new_task/" + taskApiEndpoint,
|
||||
fetcher,
|
||||
{
|
||||
onSuccess: (data) => setTasks([data]),
|
||||
revalidateOnMount: true,
|
||||
dedupingInterval: 500,
|
||||
}
|
||||
);
|
||||
|
||||
const { trigger } = useSWRMutation("/api/update_task", poster, {
|
||||
onSuccess: async (response) => {
|
||||
const newTask: ConcreteTaskResponse = await response.json();
|
||||
setTasks((oldTasks) => [...oldTasks, newTask]);
|
||||
},
|
||||
});
|
||||
|
||||
return { tasks, isLoading, trigger, error, reset: mutate };
|
||||
};
|
||||
@@ -0,0 +1,19 @@
|
||||
import type { NextApiRequest, NextApiResponse } from "next";
|
||||
import { getToken } from "next-auth/jwt";
|
||||
|
||||
/**
|
||||
* Wraps any API Route handler and verifies that the user has the appropriate
|
||||
* role before running the handler. Returns a 403 otherwise.
|
||||
*/
|
||||
const withRole = (role: string, handler: (arg0: NextApiRequest, arg1: NextApiResponse) => void) => {
|
||||
return async (req: NextApiRequest, res: NextApiResponse) => {
|
||||
const token = await getToken({ req });
|
||||
if (!token || token.role !== role) {
|
||||
res.status(403).end();
|
||||
return;
|
||||
}
|
||||
return handler(req, res);
|
||||
};
|
||||
};
|
||||
|
||||
export default withRole;
|
||||
@@ -42,7 +42,7 @@ export class OasstApiClient {
|
||||
} catch (e) {
|
||||
throw new OasstError(errorText, 0, resp.status);
|
||||
}
|
||||
throw new OasstError(error.message, error.error_code, resp.status);
|
||||
throw new OasstError(error.message ?? error, error.error_code, resp.status);
|
||||
}
|
||||
|
||||
return await resp.json();
|
||||
@@ -68,6 +68,12 @@ export class OasstApiClient {
|
||||
});
|
||||
}
|
||||
|
||||
async nackTask(taskId: string, reason: string): Promise<void> {
|
||||
return this.post(`/api/v1/tasks/${taskId}/nack`, {
|
||||
reason,
|
||||
});
|
||||
}
|
||||
|
||||
// TODO return a strongly typed Task?
|
||||
// This method is used to record interaction with task while fetching next task.
|
||||
// This is a raw Json type, so we can't use it to strongly type the task.
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
export { default } from "next-auth/middleware";
|
||||
|
||||
/**
|
||||
* Guards all pages under `/grading` and redirects them to the sign in page.
|
||||
* Guards these pages and redirects them to the sign in page.
|
||||
*/
|
||||
export const config = {
|
||||
matcher: ["/create/:path*", "/evaluate/:path*", "/account/:path*", "/dashboard"],
|
||||
matcher: ["/create/:path*", "/evaluate/:path*", "/label/:path*", "/account/:path*", "/dashboard", "/admin/:path*"],
|
||||
};
|
||||
|
||||
@@ -3,7 +3,9 @@ import "focus-visible";
|
||||
|
||||
import type { AppProps } from "next/app";
|
||||
import { SessionProvider } from "next-auth/react";
|
||||
import { FlagsProvider } from "react-feature-flags";
|
||||
import { getDefaultLayout, NextPageWithLayout } from "src/components/Layout";
|
||||
import flags from "src/flags";
|
||||
|
||||
import { Chakra, getServerSideProps } from "../styles/Chakra";
|
||||
|
||||
@@ -16,9 +18,11 @@ function MyApp({ Component, pageProps: { session, cookies, ...pageProps } }: App
|
||||
const page = getLayout(<Component {...pageProps} />);
|
||||
|
||||
return (
|
||||
<Chakra cookies={cookies}>
|
||||
<SessionProvider session={session}>{page}</SessionProvider>
|
||||
</Chakra>
|
||||
<FlagsProvider value={flags}>
|
||||
<Chakra cookies={cookies}>
|
||||
<SessionProvider session={session}>{page}</SessionProvider>
|
||||
</Chakra>
|
||||
</FlagsProvider>
|
||||
);
|
||||
}
|
||||
export { getServerSideProps };
|
||||
|
||||
@@ -2,7 +2,7 @@ import Head from "next/head";
|
||||
import { useRouter } from "next/router";
|
||||
import { useSession } from "next-auth/react";
|
||||
import { useEffect } from "react";
|
||||
import { getTransparentHeaderLayout } from "src/components/Layout";
|
||||
import { getAdminLayout } from "src/components/Layout";
|
||||
import UsersCell from "src/components/UsersCell";
|
||||
|
||||
/**
|
||||
@@ -26,10 +26,8 @@ const AdminIndex = () => {
|
||||
return;
|
||||
}
|
||||
router.push("/");
|
||||
}, [session, status]);
|
||||
}, [router, session, status]);
|
||||
|
||||
// Show the final page.
|
||||
// TODO(#237): Display a component that fetches actual user data.
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
@@ -44,6 +42,6 @@ const AdminIndex = () => {
|
||||
);
|
||||
};
|
||||
|
||||
AdminIndex.getLayout = getTransparentHeaderLayout;
|
||||
AdminIndex.getLayout = getAdminLayout;
|
||||
|
||||
export default AdminIndex;
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
import { Button, Container, FormControl, FormLabel, Input, Select, useToast } from "@chakra-ui/react";
|
||||
import { Field, Form, Formik } from "formik";
|
||||
import Head from "next/head";
|
||||
import { useRouter } from "next/router";
|
||||
import { useSession } from "next-auth/react";
|
||||
import { useEffect } from "react";
|
||||
import { getAdminLayout } from "src/components/Layout";
|
||||
import poster from "src/lib/poster";
|
||||
import prisma from "src/lib/prismadb";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
const ManageUser = ({ user }) => {
|
||||
const toast = useToast();
|
||||
const router = useRouter();
|
||||
const { data: session, status } = useSession();
|
||||
|
||||
// Check when the user session is loaded and re-route if the user is not an
|
||||
// admin. This follows the suggestion by NextJS for handling private pages:
|
||||
// https://nextjs.org/docs/api-reference/next/router#usage
|
||||
//
|
||||
// All admin pages should use the same check and routing steps.
|
||||
useEffect(() => {
|
||||
if (status === "loading") {
|
||||
return;
|
||||
}
|
||||
if (session?.user?.role === "admin") {
|
||||
return;
|
||||
}
|
||||
router.push("/");
|
||||
}, [router, session, status]);
|
||||
|
||||
// Trigger to let us update the user's role. Triggers a toast when complete.
|
||||
const { trigger } = useSWRMutation("/api/admin/update_user", poster, {
|
||||
onSuccess: () => {
|
||||
toast({
|
||||
title: "User Role Updated",
|
||||
status: "success",
|
||||
duration: 1000,
|
||||
isClosable: true,
|
||||
});
|
||||
},
|
||||
onError: () => {
|
||||
toast({
|
||||
title: "User Role update failed",
|
||||
status: "error",
|
||||
duration: 1000,
|
||||
isClosable: true,
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Manage Users - Open Assistant</title>
|
||||
<meta
|
||||
name="description"
|
||||
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
|
||||
/>
|
||||
</Head>
|
||||
<Container className="oa-basic-theme">
|
||||
<Formik
|
||||
initialValues={user}
|
||||
onSubmit={(values) => {
|
||||
trigger(values);
|
||||
}}
|
||||
>
|
||||
<Form>
|
||||
<Field name="id" type="hidden" />
|
||||
<Field name="name">
|
||||
{({ field }) => (
|
||||
<FormControl>
|
||||
<FormLabel>Username</FormLabel>
|
||||
<Input {...field} isDisabled />
|
||||
</FormControl>
|
||||
)}
|
||||
</Field>
|
||||
<Field name="email">
|
||||
{({ field }) => (
|
||||
<FormControl>
|
||||
<FormLabel>Email</FormLabel>
|
||||
<Input {...field} isDisabled />
|
||||
</FormControl>
|
||||
)}
|
||||
</Field>
|
||||
|
||||
<Field name="role">
|
||||
{({ field }) => (
|
||||
<FormControl>
|
||||
<FormLabel>Role</FormLabel>
|
||||
<Select {...field}>
|
||||
<option value="banned">Banned</option>
|
||||
<option value="general">General</option>
|
||||
<option value="admin">Admin</option>
|
||||
</Select>
|
||||
</FormControl>
|
||||
)}
|
||||
</Field>
|
||||
<Button mt={4} type="submit">
|
||||
Update
|
||||
</Button>
|
||||
</Form>
|
||||
</Formik>
|
||||
</Container>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetch the user's data on the server side when rendering.
|
||||
*/
|
||||
export async function getServerSideProps({ query }) {
|
||||
const user = await prisma.user.findUnique({
|
||||
where: { id: query.id },
|
||||
select: {
|
||||
id: true,
|
||||
name: true,
|
||||
email: true,
|
||||
role: true,
|
||||
},
|
||||
});
|
||||
return {
|
||||
props: {
|
||||
user,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
ManageUser.getLayout = getAdminLayout;
|
||||
|
||||
export default ManageUser;
|
||||
@@ -0,0 +1,22 @@
|
||||
import withRole from "src/lib/auth";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
/**
|
||||
* Update's the user's data in the database. Accessible only to admins.
|
||||
*/
|
||||
const handler = withRole("admin", async (req, res) => {
|
||||
const { id, role } = JSON.parse(req.body);
|
||||
|
||||
await prisma.user.update({
|
||||
where: {
|
||||
id,
|
||||
},
|
||||
data: {
|
||||
role,
|
||||
},
|
||||
});
|
||||
|
||||
res.status(200).end();
|
||||
});
|
||||
|
||||
export default handler;
|
||||
@@ -1,31 +1,34 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
import client from "src/lib/prismadb";
|
||||
import withRole from "src/lib/auth";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
// The number of users to fetch in any request.
|
||||
const PAGE_SIZE = 20;
|
||||
|
||||
/**
|
||||
* Returns a list of user results from the database when the requesting user is
|
||||
* a logged in admin.
|
||||
*/
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered or if the user isn't an admin.
|
||||
if (!token || token.role !== "admin") {
|
||||
res.status(403).end();
|
||||
return;
|
||||
}
|
||||
const handler = withRole("admin", async (req, res) => {
|
||||
// Figure out the pagination index and skip that number of users.
|
||||
//
|
||||
// Note: with Prisma this isn't the most efficient but it's the only possible
|
||||
// option with cuid based User IDs.
|
||||
const { pageIndex } = req.query;
|
||||
const skip = parseInt(pageIndex as string) * PAGE_SIZE || 0;
|
||||
|
||||
// Fetch 20 users.
|
||||
const users = await client.user.findMany({
|
||||
const users = await prisma.user.findMany({
|
||||
select: {
|
||||
id: true,
|
||||
role: true,
|
||||
name: true,
|
||||
email: true,
|
||||
},
|
||||
take: 20,
|
||||
skip,
|
||||
take: PAGE_SIZE,
|
||||
});
|
||||
|
||||
res.status(200).json(users);
|
||||
};
|
||||
});
|
||||
|
||||
export default handler;
|
||||
|
||||
@@ -36,9 +36,6 @@ const handler = async (req, res) => {
|
||||
},
|
||||
});
|
||||
|
||||
// Update the backend with our Task ID
|
||||
await oasstApiClient.ackTask(task.id, registeredTask.id);
|
||||
|
||||
// Send the results to the client.
|
||||
res.status(200).json(registeredTask);
|
||||
};
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { getToken } from "next-auth/jwt";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered.
|
||||
if (!token) {
|
||||
res.status(401).end();
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse out the local task ID and the interaction contents.
|
||||
const { id: frontendId, reason } = await JSON.parse(req.body);
|
||||
|
||||
const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } });
|
||||
|
||||
const task = registeredTask.task as Prisma.JsonObject;
|
||||
const id = task.id as string;
|
||||
|
||||
// Update the backend with the rejection
|
||||
await oasstApiClient.nackTask(id, reason);
|
||||
|
||||
// Send the results to the client.
|
||||
res.status(200).json({});
|
||||
};
|
||||
|
||||
export default handler;
|
||||
@@ -0,0 +1,41 @@
|
||||
import { getToken } from "next-auth/jwt";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
/**
|
||||
* Sets the Label in the Backend.
|
||||
*
|
||||
*/
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
|
||||
// Return nothing if the user isn't registered.
|
||||
if (!token) {
|
||||
res.status(401).end();
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse out the local message_id, task ID and the interaction contents.
|
||||
const { message_id, post_id, label_map, text } = await JSON.parse(req.body);
|
||||
|
||||
const interactionRes = await fetch(`${process.env.FASTAPI_URL}/api/v1/text_labels`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"X-API-Key": process.env.FASTAPI_KEY,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
type: "text_labels",
|
||||
message_id: message_id,
|
||||
labels: label_map,
|
||||
text: text,
|
||||
user: {
|
||||
id: token.sub,
|
||||
display_name: token.name || token.email,
|
||||
auth_method: "local",
|
||||
},
|
||||
}),
|
||||
});
|
||||
res.status(interactionRes.status).end();
|
||||
};
|
||||
|
||||
export default handler;
|
||||
@@ -1,3 +1,4 @@
|
||||
import { Prisma } from "@prisma/client";
|
||||
import { getToken } from "next-auth/jwt";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import prisma from "src/lib/prismadb";
|
||||
@@ -6,9 +7,11 @@ import prisma from "src/lib/prismadb";
|
||||
* Stores the task interaction with the Task Backend and then returns the next task generated.
|
||||
*
|
||||
* This implicity does a few things:
|
||||
* 1) Stores the answer with the Task Backend.
|
||||
* 2) Records the new task in our local database.
|
||||
* 3) Returns the newly created task to the client.
|
||||
* 1) Records the users answer in our local database.
|
||||
* 2) Accepts the task.
|
||||
* 3) Sends the users answer to the Task Backend.
|
||||
* 4) Records the new task in our local database.
|
||||
* 5) Returns the newly created task to the client.
|
||||
*/
|
||||
const handler = async (req, res) => {
|
||||
const token = await getToken({ req });
|
||||
@@ -20,7 +23,13 @@ const handler = async (req, res) => {
|
||||
}
|
||||
|
||||
// Parse out the local task ID and the interaction contents.
|
||||
const { id, content, update_type } = await JSON.parse(req.body);
|
||||
const { id: frontendId, content, update_type } = await JSON.parse(req.body);
|
||||
|
||||
// Accept the task so that we can complete it, this will probably go away soon.
|
||||
const registeredTask = await prisma.registeredTask.findUniqueOrThrow({ where: { id: frontendId } });
|
||||
const task = registeredTask.task as Prisma.JsonObject;
|
||||
const id = task.id as string;
|
||||
await oasstApiClient.ackTask(id, registeredTask.id);
|
||||
|
||||
// Log the interaction locally to create our user_post_id needed by the Task
|
||||
// Backend.
|
||||
@@ -29,13 +38,18 @@ const handler = async (req, res) => {
|
||||
content,
|
||||
task: {
|
||||
connect: {
|
||||
id,
|
||||
id: frontendId,
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const newTask = await oasstApiClient.interactTask(update_type, id, interaction.id, content, token);
|
||||
let newTask;
|
||||
try {
|
||||
newTask = await oasstApiClient.interactTask(update_type, frontendId, interaction.id, content, token);
|
||||
} catch (err) {
|
||||
return res.status(500).json(err);
|
||||
}
|
||||
|
||||
// Stores the new task with our database.
|
||||
const newRegisteredTask = await prisma.registeredTask.create({
|
||||
|
||||
@@ -2,17 +2,59 @@ import { Button, Input, Stack } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import Link from "next/link";
|
||||
import { useRouter } from "next/router";
|
||||
import { getCsrfToken, getProviders, signIn } from "next-auth/react";
|
||||
import React, { useRef } from "react";
|
||||
import React, { useEffect, useRef, useState } from "react";
|
||||
import { FaBug, FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
|
||||
import { AuthLayout } from "src/components/AuthLayout";
|
||||
import { Footer } from "src/components/Footer";
|
||||
import { Header } from "src/components/Header";
|
||||
|
||||
export type SignInErrorTypes =
|
||||
| "Signin"
|
||||
| "OAuthSignin"
|
||||
| "OAuthCallback"
|
||||
| "OAuthCreateAccount"
|
||||
| "EmailCreateAccount"
|
||||
| "Callback"
|
||||
| "OAuthAccountNotLinked"
|
||||
| "EmailSignin"
|
||||
| "CredentialsSignin"
|
||||
| "SessionRequired"
|
||||
| "default";
|
||||
|
||||
const errorMessages: Record<SignInErrorTypes, string> = {
|
||||
Signin: "Try signing in with a different account.",
|
||||
OAuthSignin: "Try signing in with a different account.",
|
||||
OAuthCallback: "Try signing in with the same account you used originally.",
|
||||
OAuthCreateAccount: "Try signing in with a different account.",
|
||||
EmailCreateAccount: "Try signing in with a different account.",
|
||||
Callback: "Try signing in with a different account.",
|
||||
OAuthAccountNotLinked: "To confirm your identity, sign in with the same account you used originally.",
|
||||
EmailSignin: "The e-mail could not be sent.",
|
||||
CredentialsSignin: "Sign in failed. Check the details you provided are correct.",
|
||||
SessionRequired: "Please sign in to access this page.",
|
||||
default: "Unable to sign in.",
|
||||
};
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
function Signin({ csrfToken, providers }) {
|
||||
const router = useRouter();
|
||||
const { discord, email, github, credentials } = providers;
|
||||
const emailEl = useRef(null);
|
||||
const [error, setError] = useState("");
|
||||
|
||||
useEffect(() => {
|
||||
const err = router?.query?.error;
|
||||
if (err) {
|
||||
if (typeof err === "string") {
|
||||
setError(errorMessages[err]);
|
||||
} else {
|
||||
setError(errorMessages[err[0]]);
|
||||
}
|
||||
}
|
||||
}, [router]);
|
||||
|
||||
const signinWithEmail = (ev: React.FormEvent) => {
|
||||
ev.preventDefault();
|
||||
signIn(email.id, { callbackUrl: "/dashboard", email: emailEl.current.value });
|
||||
@@ -110,6 +152,11 @@ function Signin({ csrfToken, providers }) {
|
||||
</Link>
|
||||
.
|
||||
</div>
|
||||
{error && (
|
||||
<div className="text-center mt-8">
|
||||
<p className="text-orange-600">Error: {error}</p>
|
||||
</div>
|
||||
)}
|
||||
</AuthLayout>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { getCsrfToken, getProviders } from "next-auth/react";
|
||||
import { AuthLayout } from "src/components/AuthLayout";
|
||||
|
||||
export default function Verify() {
|
||||
const { colorMode } = useColorMode();
|
||||
const bgColorClass = colorMode === "light" ? "bg-gray-50" : "bg-chakra-gray-900";
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Sign Up - Open Assistant</title>
|
||||
<meta name="Sign Up" content="Sign up to access Open Assistant" />
|
||||
</Head>
|
||||
<AuthLayout>
|
||||
<h1 className="text-lg">A sign-in link has been sent to your email address.</h1>
|
||||
</AuthLayout>
|
||||
<div className={`flex h-full justify-center items-center ${bgColorClass}`}>
|
||||
<div className={bgColorClass}>
|
||||
<h1 className="text-lg">A sign-in link has been sent to your email address.</h1>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,35 +1,12 @@
|
||||
import { Container } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useEffect, useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
import { useCreateAssistantReply } from "src/hooks/tasks/create/useCreateReply";
|
||||
|
||||
const AssistantReply = () => {
|
||||
const [tasks, setTasks] = useState([]);
|
||||
|
||||
const { isLoading, mutate } = useSWRImmutable("/api/new_task/assistant_reply ", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
setTasks([data]);
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (tasks.length == 0) {
|
||||
mutate();
|
||||
}
|
||||
}, [tasks]);
|
||||
|
||||
const { trigger } = useSWRMutation("/api/update_task", poster, {
|
||||
onSuccess: async (data) => {
|
||||
const newTask = await data.json();
|
||||
setTasks((oldTasks) => [...oldTasks, newTask]);
|
||||
},
|
||||
});
|
||||
const { tasks, isLoading, reset, trigger } = useCreateAssistantReply();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
@@ -38,7 +15,7 @@ const AssistantReply = () => {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length == 0) {
|
||||
if (tasks.length === 0) {
|
||||
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
|
||||
}
|
||||
|
||||
@@ -48,7 +25,7 @@ const AssistantReply = () => {
|
||||
<title>Reply as Assistant</title>
|
||||
<meta name="description" content="Reply as Assistant." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,35 +1,12 @@
|
||||
import { Container } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useEffect, useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
import { useCreateInitialPrompt } from "src/hooks/tasks/create/useCreateInitialPrompt";
|
||||
|
||||
const InitialPrompt = () => {
|
||||
const [tasks, setTasks] = useState([]);
|
||||
|
||||
const { isLoading, mutate } = useSWRImmutable("/api/new_task/initial_prompt ", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
setTasks([data]);
|
||||
},
|
||||
});
|
||||
|
||||
const { trigger } = useSWRMutation("/api/update_task", poster, {
|
||||
onSuccess: async (data) => {
|
||||
const newTask = await data.json();
|
||||
setTasks((oldTasks) => [...oldTasks, newTask]);
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (tasks.length == 0) {
|
||||
mutate();
|
||||
}
|
||||
}, [tasks]);
|
||||
const { tasks, isLoading, reset, trigger } = useCreateInitialPrompt();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
@@ -38,7 +15,7 @@ const InitialPrompt = () => {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length == 0) {
|
||||
if (tasks.length === 0) {
|
||||
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
|
||||
}
|
||||
|
||||
@@ -48,7 +25,7 @@ const InitialPrompt = () => {
|
||||
<title>Reply as Assistant</title>
|
||||
<meta name="description" content="Reply as Assistant." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -87,7 +87,12 @@ const SummarizeStory = () => {
|
||||
</>
|
||||
</TwoColumnsWithCards>
|
||||
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSubmitResponse={submitResponse}
|
||||
onSkipTask={fetchNextTask}
|
||||
onNextTask={fetchNextTask}
|
||||
/>
|
||||
</main>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -1,34 +1,12 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useEffect, useState } from "react";
|
||||
import { Container } from "src/components/Container";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
import { useCreatePrompterReply } from "src/hooks/tasks/create/useCreateReply";
|
||||
|
||||
const UserReply = () => {
|
||||
const [tasks, setTasks] = useState([]);
|
||||
|
||||
const { isLoading, mutate } = useSWRImmutable("/api/new_task/prompter_reply", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
setTasks([data]);
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (tasks.length == 0) {
|
||||
mutate();
|
||||
}
|
||||
}, [tasks]);
|
||||
|
||||
const { trigger } = useSWRMutation("/api/update_task", poster, {
|
||||
onSuccess: async (data) => {
|
||||
const newTask = await data.json();
|
||||
setTasks((oldTasks) => [...oldTasks, newTask]);
|
||||
},
|
||||
});
|
||||
const { tasks, isLoading, reset, trigger } = useCreatePrompterReply();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
@@ -37,14 +15,8 @@ const UserReply = () => {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length == 0) {
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<div className="flex h-full">
|
||||
<div className="text-xl font-bold mx-auto my-auto">No tasks found...</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
if (tasks.length === 0) {
|
||||
return <Container className="p-6 text-center text-gray-800">No tasks found...</Container>;
|
||||
}
|
||||
|
||||
return (
|
||||
@@ -53,7 +25,7 @@ const UserReply = () => {
|
||||
<title>Reply as Assistant</title>
|
||||
<meta name="description" content="Reply as Assistant." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,34 +1,12 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useEffect, useState } from "react";
|
||||
import { Container } from "src/components/Container";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
import { useRankAssistantRepliesTask } from "src/hooks/tasks/evaluate/useRankReplies";
|
||||
|
||||
const RankAssistantReplies = () => {
|
||||
const [tasks, setTasks] = useState([]);
|
||||
|
||||
const { isLoading, mutate } = useSWRImmutable("/api/new_task/rank_assistant_replies", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
setTasks([data]);
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (tasks.length == 0) {
|
||||
mutate();
|
||||
}
|
||||
}, [tasks]);
|
||||
|
||||
const { trigger } = useSWRMutation("/api/update_task", poster, {
|
||||
onSuccess: async (data) => {
|
||||
const newTask = await data.json();
|
||||
setTasks((oldTasks) => [...oldTasks, newTask]);
|
||||
},
|
||||
});
|
||||
const { tasks, isLoading, reset, trigger } = useRankAssistantRepliesTask();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
@@ -37,14 +15,8 @@ const RankAssistantReplies = () => {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length == 0) {
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<div className="flex h-full">
|
||||
<div className="text-xl font-bold mx-auto my-auto">No tasks found...</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
if (tasks.length === 0) {
|
||||
return <Container className="p-6 text-center">No tasks found...</Container>;
|
||||
}
|
||||
|
||||
return (
|
||||
@@ -53,7 +25,7 @@ const RankAssistantReplies = () => {
|
||||
<title>Rank Assistant Replies</title>
|
||||
<meta name="description" content="Rank Assistant Replies." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,34 +1,12 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useEffect, useState } from "react";
|
||||
import { Container } from "src/components/Container";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
import { useRankInitialPromptsTask } from "src/hooks/tasks/evaluate/useRankInitialPrompts";
|
||||
|
||||
const RankInitialPrompts = () => {
|
||||
const [tasks, setTasks] = useState([]);
|
||||
|
||||
const { isLoading, mutate } = useSWRImmutable("/api/new_task/rank_initial_prompts", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
setTasks([data]);
|
||||
},
|
||||
});
|
||||
|
||||
const { trigger } = useSWRMutation("/api/update_task", poster, {
|
||||
onSuccess: async (data) => {
|
||||
const newTask = await data.json();
|
||||
setTasks((oldTasks) => [...oldTasks, newTask]);
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (tasks.length == 0) {
|
||||
mutate();
|
||||
}
|
||||
}, [tasks]);
|
||||
const { tasks, isLoading, reset, trigger } = useRankInitialPromptsTask();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
@@ -37,14 +15,8 @@ const RankInitialPrompts = () => {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length == 0) {
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<div className="flex h-full">
|
||||
<div className="text-xl font-bold mx-auto my-auto">No tasks found...</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
if (tasks.length === 0) {
|
||||
return <Container className="p-6 text-center">No tasks found...</Container>;
|
||||
}
|
||||
|
||||
return (
|
||||
@@ -53,7 +25,7 @@ const RankInitialPrompts = () => {
|
||||
<title>Rank Initial Prompts</title>
|
||||
<meta name="description" content="Rank initial prompts." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,34 +1,12 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useEffect, useState } from "react";
|
||||
import { Container } from "src/components/Container";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Task } from "src/components/Tasks/Task";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import poster from "src/lib/poster";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
import { useRankPrompterRepliesTask } from "src/hooks/tasks/evaluate/useRankReplies";
|
||||
|
||||
const RankUserReplies = () => {
|
||||
const [tasks, setTasks] = useState([]);
|
||||
|
||||
const { isLoading, mutate } = useSWRImmutable("/api/new_task/rank_prompter_replies", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
setTasks([data]);
|
||||
},
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (tasks.length == 0) {
|
||||
mutate();
|
||||
}
|
||||
}, [tasks]);
|
||||
|
||||
const { trigger } = useSWRMutation("/api/update_task", poster, {
|
||||
onSuccess: async (data) => {
|
||||
const newTask = await data.json();
|
||||
setTasks((oldTasks) => [...oldTasks, newTask]);
|
||||
},
|
||||
});
|
||||
const { tasks, isLoading, reset, trigger } = useRankPrompterRepliesTask();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const mainBgClasses = colorMode === "light" ? "bg-slate-300 text-gray-800" : "bg-slate-900 text-white";
|
||||
@@ -37,14 +15,8 @@ const RankUserReplies = () => {
|
||||
return <LoadingScreen text="Loading..." />;
|
||||
}
|
||||
|
||||
if (tasks.length == 0) {
|
||||
return (
|
||||
<div className={`p-12 ${mainBgClasses}`}>
|
||||
<div className="flex h-full">
|
||||
<div className="text-xl font-bold mx-auto my-auto">No tasks found...</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
if (tasks.length === 0) {
|
||||
return <Container className="p-6 text-center">No tasks found...</Container>;
|
||||
}
|
||||
|
||||
return (
|
||||
@@ -53,7 +25,7 @@ const RankUserReplies = () => {
|
||||
<title>Rank User Replies</title>
|
||||
<meta name="description" content="Rank User Replies." />
|
||||
</Head>
|
||||
<Task tasks={tasks} trigger={trigger} mutate={mutate} mainBgClasses={mainBgClasses} />
|
||||
<Task tasks={tasks} trigger={trigger} mutate={reset} mainBgClasses={mainBgClasses} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -102,7 +102,12 @@ const RateSummary = () => {
|
||||
</section>
|
||||
</TwoColumnsWithCards>
|
||||
|
||||
<TaskControls tasks={tasks} onSubmitResponse={submitResponse} onSkip={fetchNextTask} />
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSubmitResponse={submitResponse}
|
||||
onSkipTask={fetchNextTask}
|
||||
onNextTask={fetchNextTask}
|
||||
/>
|
||||
</main>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
import { useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Message } from "src/components/Messages";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
|
||||
import {
|
||||
LabelAssistantReplyTaskResponse,
|
||||
useLabelAssistantReplyTask,
|
||||
} from "src/hooks/tasks/labeling/useLabelAssistantReply";
|
||||
|
||||
const LabelAssistantReply = () => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>([]);
|
||||
|
||||
const { tasks, isLoading, submit, reset } = useLabelAssistantReplyTask();
|
||||
|
||||
if (isLoading || tasks.length === 0) {
|
||||
return <LoadingScreen />;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
const messages: Message[] = [
|
||||
...task.conversation.messages,
|
||||
{ text: task.reply, is_assistant: true, message_id: task.message_id },
|
||||
];
|
||||
|
||||
return (
|
||||
<LabelTask
|
||||
title="Label Assistant Reply"
|
||||
desc="Given the following discussion, provide labels for the final prompt"
|
||||
messages={<MessageTable messages={messages} />}
|
||||
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
|
||||
controls={
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSkipTask={() => reset()}
|
||||
onNextTask={reset}
|
||||
onSubmitResponse={({ id, task }: LabelAssistantReplyTaskResponse) =>
|
||||
submit(id, task.message_id, task.reply, task.valid_labels, sliderValues)
|
||||
}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default LabelAssistantReply;
|
||||
@@ -0,0 +1,42 @@
|
||||
import { useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { MessageView } from "src/components/Messages";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
|
||||
import {
|
||||
LabelInitialPromptTaskResponse,
|
||||
useLabelInitialPromptTask,
|
||||
} from "src/hooks/tasks/labeling/useLabelInitialPrompt";
|
||||
|
||||
const LabelInitialPrompt = () => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>([]);
|
||||
|
||||
const { tasks, isLoading, submit, reset } = useLabelInitialPromptTask();
|
||||
|
||||
if (isLoading || tasks.length === 0) {
|
||||
return <LoadingScreen />;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
|
||||
return (
|
||||
<LabelTask
|
||||
title="Label Initial Prompt"
|
||||
desc="Provide labels for the following prompt"
|
||||
messages={<MessageView text={task.prompt} is_assistant message_id={task.message_id} />}
|
||||
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
|
||||
controls={
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSkipTask={() => reset()}
|
||||
onNextTask={reset}
|
||||
onSubmitResponse={({ id, task }: LabelInitialPromptTaskResponse) =>
|
||||
submit(id, task.message_id, task.prompt, task.valid_labels, sliderValues)
|
||||
}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default LabelInitialPrompt;
|
||||
@@ -0,0 +1,47 @@
|
||||
import { useState } from "react";
|
||||
import { LoadingScreen } from "src/components/Loading/LoadingScreen";
|
||||
import { Message } from "src/components/Messages";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import { TaskControls } from "src/components/Survey/TaskControls";
|
||||
import { LabelSliderGroup, LabelTask } from "src/components/Tasks/LabelTask";
|
||||
import {
|
||||
LabelPrompterReplyTaskResponse,
|
||||
useLabelPrompterReplyTask,
|
||||
} from "src/hooks/tasks/labeling/useLabelPrompterReply";
|
||||
|
||||
const LabelPrompterReply = () => {
|
||||
const [sliderValues, setSliderValues] = useState<number[]>([]);
|
||||
|
||||
const { tasks, isLoading, submit, reset } = useLabelPrompterReplyTask();
|
||||
|
||||
if (isLoading || tasks.length === 0) {
|
||||
return <LoadingScreen />;
|
||||
}
|
||||
|
||||
const task = tasks[0].task;
|
||||
const messages: Message[] = [
|
||||
...task.conversation.messages,
|
||||
{ text: task.reply, is_assistant: false, message_id: task.message_id },
|
||||
];
|
||||
|
||||
return (
|
||||
<LabelTask
|
||||
title="Label Prompter Reply"
|
||||
desc="Given the following discussion, provide labels for the final prompt"
|
||||
messages={<MessageTable messages={messages} />}
|
||||
inputs={<LabelSliderGroup labelIDs={task.valid_labels} onChange={setSliderValues} />}
|
||||
controls={
|
||||
<TaskControls
|
||||
tasks={tasks}
|
||||
onSkipTask={() => reset()}
|
||||
onNextTask={reset}
|
||||
onSubmitResponse={({ id, task }: LabelPrompterReplyTaskResponse) =>
|
||||
submit(id, task.message_id, task.reply, task.valid_labels, sliderValues)
|
||||
}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default LabelPrompterReply;
|
||||
@@ -2,6 +2,7 @@ import { Box, CircularProgress, SimpleGrid, Text, useColorModeValue } from "@cha
|
||||
import Head from "next/head";
|
||||
import { useEffect, useState } from "react";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { Message } from "src/components/Messages";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import fetcher from "src/lib/fetcher";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
@@ -10,29 +11,28 @@ const MessagesDashboard = () => {
|
||||
const boxBgColor = useColorModeValue("white", "gray.700");
|
||||
const boxAccentColor = useColorModeValue("gray.200", "gray.900");
|
||||
|
||||
const [messages, setMessages] = useState([]);
|
||||
const [userMessages, setUserMessages] = useState([]);
|
||||
const [messages, setMessages] = useState<Message[]>(null);
|
||||
const [userMessages, setUserMessages] = useState<Message[]>(null);
|
||||
|
||||
const { isLoading: isLoadingAll, mutate: mutateAll } = useSWRImmutable("/api/messages", fetcher, {
|
||||
onSuccess: (data) => {
|
||||
setMessages(data);
|
||||
},
|
||||
onSuccess: setMessages,
|
||||
});
|
||||
|
||||
const { isLoading: isLoadingUser, mutate: mutateUser } = useSWRImmutable(`/api/messages/user`, fetcher, {
|
||||
onSuccess: (data) => {
|
||||
setUserMessages(data);
|
||||
},
|
||||
onSuccess: setUserMessages,
|
||||
});
|
||||
|
||||
const receivedMessages = !isLoadingAll && Array.isArray(messages);
|
||||
const receivedUserMessages = !isLoadingUser && Array.isArray(userMessages);
|
||||
|
||||
useEffect(() => {
|
||||
if (messages.length == 0) {
|
||||
if (!receivedMessages) {
|
||||
mutateAll();
|
||||
}
|
||||
if (userMessages.length == 0) {
|
||||
if (!receivedUserMessages) {
|
||||
mutateUser();
|
||||
}
|
||||
}, [messages, userMessages]);
|
||||
}, [receivedMessages, mutateAll, receivedUserMessages, mutateUser]);
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -52,7 +52,7 @@ const MessagesDashboard = () => {
|
||||
borderRadius="xl"
|
||||
className="p-6 shadow-sm"
|
||||
>
|
||||
{isLoadingAll ? <CircularProgress isIndeterminate /> : <MessageTable messages={messages} />}
|
||||
{receivedMessages ? <MessageTable messages={messages} /> : <CircularProgress isIndeterminate />}
|
||||
</Box>
|
||||
</Box>
|
||||
<Box>
|
||||
@@ -66,7 +66,7 @@ const MessagesDashboard = () => {
|
||||
borderRadius="xl"
|
||||
className="p-6 shadow-sm"
|
||||
>
|
||||
{isLoadingUser ? <CircularProgress isIndeterminate /> : <MessageTable messages={userMessages} />}
|
||||
{receivedUserMessages ? <MessageTable messages={userMessages} /> : <CircularProgress isIndeterminate />}
|
||||
</Box>
|
||||
</Box>
|
||||
</SimpleGrid>
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
import Head from "next/head";
|
||||
import { TaskOption } from "src/components/Dashboard";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
|
||||
const AllTasks = () => {
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>All Tasks - Open Assistant</title>
|
||||
<meta name="description" content="All tasks for Open Assistant." />
|
||||
</Head>
|
||||
<TaskOption />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
AllTasks.getLayout = (page) => getDashboardLayout(page);
|
||||
|
||||
export default AllTasks;
|
||||
Reference in New Issue
Block a user