mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
Merge remote-tracking branch 'upstream/main'
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
name: Deploy to dev machine
|
||||
name: Deploy to node
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
@@ -23,6 +23,7 @@ on:
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
environment: ${{ inputs.stack-name }}
|
||||
env:
|
||||
WEB_ADMIN_USERS: ${{ secrets.DEV_WEB_ADMIN_USERS }}
|
||||
WEB_DISCORD_CLIENT_ID: ${{ secrets.DEV_WEB_DISCORD_CLIENT_ID }}
|
||||
@@ -32,6 +33,9 @@ jobs:
|
||||
WEB_EMAIL_SERVER_PORT: ${{ secrets.DEV_WEB_EMAIL_SERVER_PORT }}
|
||||
WEB_EMAIL_SERVER_USER: ${{ secrets.DEV_WEB_EMAIL_SERVER_USER }}
|
||||
WEB_NEXTAUTH_SECRET: ${{ secrets.DEV_WEB_NEXTAUTH_SECRET }}
|
||||
S3_BUCKET_NAME: ${{ secrets.S3_BUCKET_NAME }}
|
||||
AWS_ACCESS_KEY: ${{ secrets.AWS_ACCESS_KEY }}
|
||||
AWS_SECRET_KEY: ${{ secrets.AWS_SECRET_KEY }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
@@ -39,7 +43,7 @@ jobs:
|
||||
uses: dawidd6/action-ansible-playbook@v2
|
||||
with:
|
||||
# Required, playbook filepath
|
||||
playbook: deploy-dev.yaml
|
||||
playbook: deploy-to-node.yaml
|
||||
# Optional, directory where playbooks live
|
||||
directory: ansible
|
||||
# Optional, SSH private key
|
||||
@@ -49,4 +53,9 @@ jobs:
|
||||
[dev]
|
||||
dev01 ansible_host=${{secrets.DEV_NODE_IP}} ansible_connection=ssh ansible_user=web-team
|
||||
options: |
|
||||
--extra-vars "stack_name=${{inputs.stack-name}} image_tag=${{inputs.image-tag}} backend_port=${{inputs.backend-port}} website_port=${{inputs.website-port}}"
|
||||
--extra-vars "stack_name=${{inputs.stack-name}} \
|
||||
image_tag=${{inputs.image-tag}} \
|
||||
backend_port=${{inputs.backend-port}} \
|
||||
website_port=${{inputs.website-port}} \
|
||||
postgres_password=${{secrets.POSTGRES_PASSWORD}} \
|
||||
web_api_key=${{secrets.WEB_API_KEY}}"
|
||||
@@ -0,0 +1,16 @@
|
||||
name: Deploy to prod
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- production
|
||||
|
||||
jobs:
|
||||
deploy-to-prod:
|
||||
uses: ./.github/workflows/deploy-to-node.yaml
|
||||
secrets: inherit
|
||||
with:
|
||||
stack-name: production
|
||||
image-tag: ${{ vars.PROD_IMAGE_TAG }}
|
||||
backend-port: 8280
|
||||
website-port: 3200
|
||||
@@ -35,9 +35,9 @@ jobs:
|
||||
context: .
|
||||
dockerfile: docker/Dockerfile.discord-bot
|
||||
build-args: ""
|
||||
deploy-dev:
|
||||
deploy-to-node:
|
||||
needs: [build-backend, build-web, build-bot]
|
||||
uses: ./.github/workflows/deploy-dev.yaml
|
||||
uses: ./.github/workflows/deploy-to-node.yaml
|
||||
secrets: inherit
|
||||
with:
|
||||
stack-name: ${{ github.event_name == 'release' && 'staging' || 'dev' }}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# ansible playbook to set up some docker containers
|
||||
|
||||
- name: Deploy to dev node
|
||||
- name: Deploy to node
|
||||
hosts: dev
|
||||
gather_facts: true
|
||||
vars:
|
||||
@@ -8,6 +8,8 @@
|
||||
image_tag: latest
|
||||
backend_port: 8080
|
||||
website_port: 3000
|
||||
postgres_password: postgres
|
||||
web_api_key: "1234"
|
||||
tasks:
|
||||
- name: Create network
|
||||
community.docker.docker_network:
|
||||
@@ -44,6 +46,14 @@
|
||||
volumes:
|
||||
- "./{{ stack_name }}/redis.conf:/usr/local/etc/redis/redis.conf"
|
||||
|
||||
- name: Create volumes for postgres
|
||||
community.docker.docker_volume:
|
||||
name: "oasst-{{ stack_name }}-postgres-{{ item.name }}"
|
||||
state: present
|
||||
loop:
|
||||
- name: backend
|
||||
- name: web
|
||||
|
||||
- name: Create postgres containers
|
||||
community.docker.docker_container:
|
||||
name: "oasst-{{ stack_name }}-postgres-{{ item.name }}"
|
||||
@@ -54,8 +64,11 @@
|
||||
network_mode: "oasst-{{ stack_name }}"
|
||||
env:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_PASSWORD: "{{ postgres_password }}"
|
||||
POSTGRES_DB: postgres
|
||||
volumes:
|
||||
- "oasst-{{ stack_name }}-postgres-{{ item.name
|
||||
}}:/var/lib/postgresql/data"
|
||||
healthcheck:
|
||||
test: ["CMD", "pg_isready", "-U", "postgres"]
|
||||
interval: 2s
|
||||
@@ -65,6 +78,29 @@
|
||||
- name: backend
|
||||
- name: web
|
||||
|
||||
- name: Copy pgbackrest.conf to managed node
|
||||
ansible.builtin.copy:
|
||||
src: ./pgbackrest.conf
|
||||
dest: "./{{ stack_name }}/pgbackrest.conf"
|
||||
mode: 0644
|
||||
|
||||
- name: Create pgbackrest container
|
||||
community.docker.docker_container:
|
||||
name: "oasst-{{ stack_name }}-pgbackrest"
|
||||
image: woblerr/pgbackrest:2.43
|
||||
state: "{{ 'stopped' if stack_name == 'production' else 'absent' }}"
|
||||
network_mode: "oasst-{{ stack_name }}"
|
||||
volumes:
|
||||
- "./{{ stack_name }}/pgbackrest.conf:/etc/pgbackrest/pgbackrest.conf"
|
||||
- "oasst-{{ stack_name }}-postgres-backend:/var/lib/postgresql/data"
|
||||
env:
|
||||
PGBACKREST_REPO1_S3_BUCKET:
|
||||
"{{ lookup('ansible.builtin.env', 'S3_BUCKET_NAME') }}"
|
||||
PGBACKREST_REPO1_S3_KEY:
|
||||
"{{ lookup('ansible.builtin.env', 'AWS_ACCESS_KEY') }}"
|
||||
PGBACKREST_REPO1_S3_KEY_SECRET:
|
||||
"{{ lookup('ansible.builtin.env', 'AWS_SECRET_KEY') }}"
|
||||
|
||||
- name: Run the oasst oasst-backend
|
||||
community.docker.docker_container:
|
||||
name: "oasst-{{ stack_name }}-backend"
|
||||
@@ -76,15 +112,18 @@
|
||||
network_mode: "oasst-{{ stack_name }}"
|
||||
env:
|
||||
POSTGRES_HOST: "oasst-{{ stack_name }}-postgres-backend"
|
||||
POSTGRES_PASSWORD: "{{ postgres_password }}"
|
||||
REDIS_HOST: "oasst-{{ stack_name }}-redis"
|
||||
DEBUG_ALLOW_DEBUG_API_KEY: "true"
|
||||
DEBUG_USE_SEED_DATA: "true"
|
||||
DEBUG_USE_SEED_DATA:
|
||||
"{{ 'true' if stack_name == 'dev' else 'false' }}"
|
||||
DEBUG_ALLOW_SELF_LABELING:
|
||||
"{{ 'true' if stack_name == 'dev' else 'false' }}"
|
||||
MAX_WORKERS: "1"
|
||||
RATE_LIMIT: "{{ 'false' if stack_name == 'dev' else 'true' }}"
|
||||
DEBUG_SKIP_EMBEDDING_COMPUTATION: "true"
|
||||
DEBUG_SKIP_TOXICITY_CALCULATION: "true"
|
||||
DEBUG_SKIP_TOXICITY_CALCULATION:
|
||||
"{{ 'true' if stack_name == 'dev' else 'false' }}"
|
||||
OFFICIAL_WEB_API_KEY: "{{ web_api_key }}"
|
||||
ports:
|
||||
- "{{ backend_port }}:8080"
|
||||
|
||||
@@ -100,9 +139,9 @@
|
||||
env:
|
||||
ADMIN_USERS: "{{ lookup('ansible.builtin.env', 'WEB_ADMIN_USERS') }}"
|
||||
DATABASE_URL:
|
||||
"postgres://postgres:postgres@oasst-{{ stack_name
|
||||
"postgres://postgres:{{ postgres_password }}@oasst-{{ stack_name
|
||||
}}-postgres-web/postgres"
|
||||
DEBUG_LOGIN: "true"
|
||||
DEBUG_LOGIN: "{{ 'true' if stack_name == 'dev' else 'false' }}"
|
||||
DISCORD_CLIENT_ID:
|
||||
"{{ lookup('ansible.builtin.env', 'WEB_DISCORD_CLIENT_ID') }}"
|
||||
DISCORD_CLIENT_SECRET:
|
||||
@@ -117,10 +156,12 @@
|
||||
EMAIL_SERVER_USER:
|
||||
"{{ lookup('ansible.builtin.env', 'WEB_EMAIL_SERVER_USER') }}"
|
||||
FASTAPI_URL: "http://oasst-{{ stack_name }}-backend:8080"
|
||||
FASTAPI_KEY: "1234"
|
||||
FASTAPI_KEY: "{{ web_api_key }}"
|
||||
NEXTAUTH_SECRET:
|
||||
"{{ lookup('ansible.builtin.env', 'WEB_NEXTAUTH_SECRET') }}"
|
||||
NEXTAUTH_URL: http://web.{{ stack_name }}.open-assistant.io/
|
||||
NEXTAUTH_URL:
|
||||
"{{ 'https://open-assistant.io/' if stack_name == 'production' else
|
||||
('https://web.' + stack_name + '.open-assistant.io/') }}"
|
||||
ports:
|
||||
- "{{ website_port }}:3000"
|
||||
command: bash wait-for-postgres.sh node server.js
|
||||
@@ -0,0 +1,24 @@
|
||||
[oasst]
|
||||
pg1-path=/var/lib/postgresql/data
|
||||
|
||||
[global]
|
||||
repo1-retention-full=3 # keep last 3 backups
|
||||
repo1-type=s3
|
||||
repo1-path=/oasst-prod
|
||||
repo1-s3-region=us-east-1
|
||||
repo1-s3-endpoint=s3.amazonaws.com
|
||||
# repo1-s3-bucket=$S3_BUCKET_NAME
|
||||
# repo1-s3-key=$AWS_ACCESS_KEY
|
||||
# repo1-s3-key-secret=$AWS_SECRET_KEY
|
||||
|
||||
# Force a checkpoint to start backup immediately.
|
||||
start-fast=y
|
||||
# Use delta restore.
|
||||
delta=y
|
||||
|
||||
# Enable ZSTD compression.
|
||||
compress-type=zst
|
||||
compress-level=6
|
||||
|
||||
log-level-console=info
|
||||
log-level-file=debug
|
||||
+100
-81
@@ -11,7 +11,7 @@ import redis.asyncio as redis
|
||||
from fastapi_limiter import FastAPILimiter
|
||||
from fastapi_utils.tasks import repeat_every
|
||||
from loguru import logger
|
||||
from oasst_backend.api.deps import get_dummy_api_client
|
||||
from oasst_backend.api.deps import api_auth, create_api_client
|
||||
from oasst_backend.api.v1.api import api_router
|
||||
from oasst_backend.api.v1.utils import prepare_conversation
|
||||
from oasst_backend.config import settings
|
||||
@@ -20,6 +20,7 @@ from oasst_backend.models import message_tree_state
|
||||
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
|
||||
from oasst_backend.tree_manager import TreeManager
|
||||
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from pydantic import BaseModel
|
||||
@@ -76,6 +77,24 @@ if settings.UPDATE_ALEMBIC:
|
||||
logger.exception("Alembic upgrade failed on startup")
|
||||
|
||||
|
||||
if settings.OFFICIAL_WEB_API_KEY:
|
||||
|
||||
@app.on_event("startup")
|
||||
def create_official_web_api_client():
|
||||
with Session(engine) as session:
|
||||
try:
|
||||
api_auth(settings.OFFICIAL_WEB_API_KEY, db=session)
|
||||
except OasstError:
|
||||
logger.info("Creating official web API client")
|
||||
create_api_client(
|
||||
session=session,
|
||||
api_key=settings.OFFICIAL_WEB_API_KEY,
|
||||
description="The official web client for the OASST backend.",
|
||||
frontend_type="web",
|
||||
trusted=True,
|
||||
)
|
||||
|
||||
|
||||
if settings.RATE_LIMIT:
|
||||
|
||||
@app.on_event("startup")
|
||||
@@ -102,7 +121,8 @@ if settings.RATE_LIMIT:
|
||||
if settings.DEBUG_USE_SEED_DATA:
|
||||
|
||||
@app.on_event("startup")
|
||||
def seed_data():
|
||||
@managed_tx_function(auto_commit=CommitMode.COMMIT)
|
||||
def create_seed_data(session: Session):
|
||||
class DummyMessage(BaseModel):
|
||||
task_message_id: str
|
||||
user_message_id: str
|
||||
@@ -111,75 +131,78 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
role: str
|
||||
tree_state: Optional[message_tree_state.State]
|
||||
|
||||
if not settings.OFFICIAL_WEB_API_KEY:
|
||||
raise ValueError("Cannot use seed data without OFFICIAL_WEB_API_KEY")
|
||||
|
||||
try:
|
||||
logger.info("Seed data check began")
|
||||
with Session(engine) as db:
|
||||
api_client = get_dummy_api_client(db)
|
||||
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
|
||||
ur = UserRepository(db=db, api_client=api_client)
|
||||
tr = TaskRepository(db=db, api_client=api_client, client_user=dummy_user, user_repository=ur)
|
||||
pr = PromptRepository(
|
||||
db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
|
||||
)
|
||||
tm = TreeManager(db, pr)
|
||||
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=session)
|
||||
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
|
||||
with open(settings.DEBUG_USE_SEED_DATA_PATH) as f:
|
||||
dummy_messages_raw = json.load(f)
|
||||
ur = UserRepository(db=session, api_client=api_client)
|
||||
tr = TaskRepository(db=session, api_client=api_client, client_user=dummy_user, user_repository=ur)
|
||||
pr = PromptRepository(
|
||||
db=session, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
|
||||
)
|
||||
tm = TreeManager(session, pr)
|
||||
|
||||
dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw]
|
||||
with open(settings.DEBUG_USE_SEED_DATA_PATH) as f:
|
||||
dummy_messages_raw = json.load(f)
|
||||
|
||||
for msg in dummy_messages:
|
||||
task = tr.fetch_task_by_frontend_message_id(msg.task_message_id)
|
||||
if task and not task.ack:
|
||||
logger.warning("Deleting unacknowledged seed data task")
|
||||
db.delete(task)
|
||||
task = None
|
||||
if not task:
|
||||
if msg.parent_message_id is None:
|
||||
task = tr.store_task(
|
||||
protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
|
||||
)
|
||||
else:
|
||||
parent_message = pr.fetch_message_by_frontend_message_id(
|
||||
msg.parent_message_id, fail_if_missing=True
|
||||
)
|
||||
conversation_messages = pr.fetch_message_conversation(parent_message)
|
||||
conversation = prepare_conversation(conversation_messages)
|
||||
if msg.role == "assistant":
|
||||
task = tr.store_task(
|
||||
protocol_schema.AssistantReplyTask(conversation=conversation),
|
||||
message_tree_id=parent_message.message_tree_id,
|
||||
parent_message_id=parent_message.id,
|
||||
)
|
||||
else:
|
||||
task = tr.store_task(
|
||||
protocol_schema.PrompterReplyTask(conversation=conversation),
|
||||
message_tree_id=parent_message.message_tree_id,
|
||||
parent_message_id=parent_message.id,
|
||||
)
|
||||
tr.bind_frontend_message_id(task.id, msg.task_message_id)
|
||||
message = pr.store_text_reply(
|
||||
msg.text,
|
||||
msg.task_message_id,
|
||||
msg.user_message_id,
|
||||
review_count=5,
|
||||
review_result=True,
|
||||
check_tree_state=False,
|
||||
)
|
||||
if message.parent_id is None:
|
||||
tm._insert_default_state(
|
||||
root_message_id=message.id, state=msg.tree_state or message_tree_state.State.GROWING
|
||||
)
|
||||
db.commit()
|
||||
dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw]
|
||||
|
||||
logger.info(
|
||||
f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}"
|
||||
for msg in dummy_messages:
|
||||
task = tr.fetch_task_by_frontend_message_id(msg.task_message_id)
|
||||
if task and not task.ack:
|
||||
logger.warning("Deleting unacknowledged seed data task")
|
||||
session.delete(task)
|
||||
task = None
|
||||
if not task:
|
||||
if msg.parent_message_id is None:
|
||||
task = tr.store_task(
|
||||
protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
|
||||
)
|
||||
else:
|
||||
logger.debug(f"seed data task found: {task.id}")
|
||||
parent_message = pr.fetch_message_by_frontend_message_id(
|
||||
msg.parent_message_id, fail_if_missing=True
|
||||
)
|
||||
conversation_messages = pr.fetch_message_conversation(parent_message)
|
||||
conversation = prepare_conversation(conversation_messages)
|
||||
if msg.role == "assistant":
|
||||
task = tr.store_task(
|
||||
protocol_schema.AssistantReplyTask(conversation=conversation),
|
||||
message_tree_id=parent_message.message_tree_id,
|
||||
parent_message_id=parent_message.id,
|
||||
)
|
||||
else:
|
||||
task = tr.store_task(
|
||||
protocol_schema.PrompterReplyTask(conversation=conversation),
|
||||
message_tree_id=parent_message.message_tree_id,
|
||||
parent_message_id=parent_message.id,
|
||||
)
|
||||
tr.bind_frontend_message_id(task.id, msg.task_message_id)
|
||||
message = pr.store_text_reply(
|
||||
msg.text,
|
||||
msg.task_message_id,
|
||||
msg.user_message_id,
|
||||
review_count=5,
|
||||
review_result=True,
|
||||
check_tree_state=False,
|
||||
)
|
||||
if message.parent_id is None:
|
||||
tm._insert_default_state(
|
||||
root_message_id=message.id, state=msg.tree_state or message_tree_state.State.GROWING
|
||||
)
|
||||
session.flush()
|
||||
|
||||
logger.info("Seed data check completed")
|
||||
logger.info(
|
||||
f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"seed data task found: {task.id}")
|
||||
|
||||
logger.info("Seed data check completed")
|
||||
|
||||
except Exception:
|
||||
logger.exception("Seed data insertion failed")
|
||||
@@ -199,48 +222,44 @@ def ensure_tree_states():
|
||||
|
||||
@app.on_event("startup")
|
||||
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_DAY, wait_first=False)
|
||||
def update_leader_board_day() -> None:
|
||||
@managed_tx_function(auto_commit=CommitMode.COMMIT)
|
||||
def update_leader_board_day(session: Session) -> None:
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.day)
|
||||
session.commit()
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.day)
|
||||
except Exception:
|
||||
logger.exception("Error during leaderboard update (daily)")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_WEEK, wait_first=False)
|
||||
def update_leader_board_week() -> None:
|
||||
@managed_tx_function(auto_commit=CommitMode.COMMIT)
|
||||
def update_leader_board_week(session: Session) -> None:
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.week)
|
||||
session.commit()
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.week)
|
||||
except Exception:
|
||||
logger.exception("Error during user states update (weekly)")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_MONTH, wait_first=False)
|
||||
def update_leader_board_month() -> None:
|
||||
@managed_tx_function(auto_commit=CommitMode.COMMIT)
|
||||
def update_leader_board_month(session: Session) -> None:
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.month)
|
||||
session.commit()
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.month)
|
||||
except Exception:
|
||||
logger.exception("Error during user states update (monthly)")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_TOTAL, wait_first=False)
|
||||
def update_leader_board_total() -> None:
|
||||
@managed_tx_function(auto_commit=CommitMode.COMMIT)
|
||||
def update_leader_board_total(session: Session) -> None:
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.total)
|
||||
session.commit()
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.total)
|
||||
except Exception:
|
||||
logger.exception("Error during user states update (total)")
|
||||
|
||||
|
||||
@@ -61,33 +61,11 @@ def create_api_client(
|
||||
return api_client
|
||||
|
||||
|
||||
def get_dummy_api_client(session: Session) -> ApiClient:
|
||||
# make sure that a dummy api key exits in db (foreign key references)
|
||||
DUMMY_API_KEY = "1234"
|
||||
api_client: ApiClient = session.query(ApiClient).filter(ApiClient.api_key == DUMMY_API_KEY).first()
|
||||
if api_client is None:
|
||||
logger.info(f"ANY_API_KEY missing, inserting api_key: {DUMMY_API_KEY}")
|
||||
api_client = create_api_client(
|
||||
session=session,
|
||||
api_key=DUMMY_API_KEY,
|
||||
description="Dummy api key for debugging",
|
||||
trusted=True,
|
||||
frontend_type="Test frontend",
|
||||
)
|
||||
session.add(api_client)
|
||||
session.commit()
|
||||
return api_client
|
||||
|
||||
|
||||
def api_auth(
|
||||
api_key: APIKey,
|
||||
db: Session,
|
||||
) -> ApiClient:
|
||||
if api_key or settings.DEBUG_SKIP_API_KEY_CHECK:
|
||||
|
||||
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_DEBUG_API_KEY:
|
||||
return get_dummy_api_client(db)
|
||||
|
||||
if api_key:
|
||||
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
|
||||
if api_client is not None and api_client.enabled:
|
||||
return api_client
|
||||
|
||||
@@ -19,14 +19,14 @@ router = APIRouter()
|
||||
def get_users(
|
||||
api_client_id: Optional[UUID] = None,
|
||||
max_count: Optional[int] = Query(100, gt=0, le=10000),
|
||||
gte: Optional[str] = None,
|
||||
gt: Optional[str] = None,
|
||||
lt: Optional[str] = None,
|
||||
auth_method: Optional[str] = None,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
ur = UserRepository(db, api_client)
|
||||
users = ur.query_users(api_client_id=api_client_id, limit=max_count, gte=gte, lt=lt, auth_method=auth_method)
|
||||
users = ur.query_users(api_client_id=api_client_id, limit=max_count, gt=gt, lt=lt, auth_method=auth_method)
|
||||
return [u.to_protocol_frontend_user() for u in users]
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
@@ -48,6 +48,27 @@ def request_task(
|
||||
return task
|
||||
|
||||
|
||||
@router.post("/availability", response_model=dict[protocol_schema.TaskRequestType, int])
|
||||
def tasks_availability(
|
||||
*,
|
||||
user: Optional[protocol_schema.User] = None,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
):
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, client_user=user)
|
||||
tm = TreeManager(db, pr)
|
||||
return tm.determine_task_availability()
|
||||
|
||||
except OasstError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Task availability query failed.")
|
||||
raise OasstError("Task availability query failed.", OasstErrorCode.TASK_AVAILABILITY_QUERY_FAILED)
|
||||
|
||||
|
||||
@router.post("/{task_id}/ack", response_model=None, status_code=HTTP_204_NO_CONTENT)
|
||||
def tasks_acknowledge(
|
||||
*,
|
||||
|
||||
@@ -55,10 +55,13 @@ class TreeManagerConfiguration(BaseModel):
|
||||
mandatory_labels_prompter_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam]
|
||||
"""Mandatory labels in text-labeling tasks for prompter replies."""
|
||||
|
||||
rank_prompter_replies: bool = False
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = "open-assistant backend"
|
||||
API_V1_STR: str = "/api/v1"
|
||||
OFFICIAL_WEB_API_KEY: str = "1234"
|
||||
|
||||
POSTGRES_HOST: str = "localhost"
|
||||
POSTGRES_PORT: str = "5432"
|
||||
@@ -66,13 +69,12 @@ class Settings(BaseSettings):
|
||||
POSTGRES_PASSWORD: str = "postgres"
|
||||
POSTGRES_DB: str = "postgres"
|
||||
DATABASE_URI: Optional[PostgresDsn] = None
|
||||
DATABASE_MAX_TX_RETRY_COUNT: int = 3
|
||||
|
||||
RATE_LIMIT: bool = True
|
||||
REDIS_HOST: str = "localhost"
|
||||
REDIS_PORT: str = "6379"
|
||||
|
||||
DEBUG_ALLOW_DEBUG_API_KEY: bool = False
|
||||
DEBUG_SKIP_API_KEY_CHECK: bool = False
|
||||
DEBUG_USE_SEED_DATA: bool = False
|
||||
DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = (
|
||||
Path(__file__).parent.parent / "test_data/realistic/realistic_seed_data.json"
|
||||
@@ -80,6 +82,7 @@ class Settings(BaseSettings):
|
||||
DEBUG_ALLOW_SELF_LABELING: bool = False # allow users to label their own messages
|
||||
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
|
||||
DEBUG_SKIP_TOXICITY_CALCULATION: bool = False
|
||||
DEBUG_DATABASE_ECHO: bool = False
|
||||
|
||||
HUGGING_FACE_API_KEY: str = ""
|
||||
|
||||
|
||||
@@ -5,4 +5,4 @@ from sqlmodel import create_engine
|
||||
if settings.DATABASE_URI is None:
|
||||
raise OasstError("DATABASE_URI is not set", error_code=OasstErrorCode.DATABASE_URI_NOT_SET)
|
||||
|
||||
engine = create_engine(settings.DATABASE_URI)
|
||||
engine = create_engine(settings.DATABASE_URI, echo=settings.DEBUG_DATABASE_ECHO, isolation_level="REPEATABLE READ")
|
||||
|
||||
@@ -4,6 +4,7 @@ from uuid import UUID
|
||||
|
||||
from oasst_backend.models import ApiClient, Journal, Task, User
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer, payload_type
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
|
||||
from oasst_shared.utils import utcnow
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import Session
|
||||
@@ -80,6 +81,7 @@ class JournalWriter:
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def log(
|
||||
self,
|
||||
*,
|
||||
@@ -115,7 +117,4 @@ class JournalWriter:
|
||||
)
|
||||
|
||||
self.db.add(entry)
|
||||
if commit:
|
||||
self.db.commit()
|
||||
|
||||
return entry
|
||||
|
||||
@@ -48,6 +48,8 @@ def payload_column_type(pydantic_type):
|
||||
class PayloadJSONBType(TypeDecorator, Generic[T]):
|
||||
impl = pg.JSONB()
|
||||
|
||||
cache_ok = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
json_encoder=json,
|
||||
|
||||
@@ -24,6 +24,7 @@ from oasst_backend.models import (
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import SystemStats
|
||||
@@ -67,6 +68,7 @@ class PromptRepository:
|
||||
)
|
||||
return message
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def insert_message(
|
||||
self,
|
||||
*,
|
||||
@@ -104,8 +106,8 @@ class PromptRepository:
|
||||
review_result=review_result,
|
||||
)
|
||||
self.db.add(message)
|
||||
self.db.commit()
|
||||
self.db.refresh(message)
|
||||
|
||||
# self.db.refresh(message)
|
||||
return message
|
||||
|
||||
def _validate_task(
|
||||
@@ -134,6 +136,7 @@ class PromptRepository:
|
||||
def fetch_tree_state(self, message_tree_id: UUID) -> MessageTreeState:
|
||||
return self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one()
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def store_text_reply(
|
||||
self,
|
||||
text: str,
|
||||
@@ -150,7 +153,7 @@ class PromptRepository:
|
||||
self._validate_task(task)
|
||||
|
||||
# If there's no parent message assume user started new conversation
|
||||
role = "prompter"
|
||||
role = None
|
||||
depth = 0
|
||||
|
||||
if task.parent_message_id:
|
||||
@@ -170,10 +173,23 @@ class PromptRepository:
|
||||
self.db.add(parent_message)
|
||||
|
||||
depth = parent_message.depth + 1
|
||||
if parent_message.role == "assistant":
|
||||
role = "prompter"
|
||||
else:
|
||||
role = "assistant"
|
||||
|
||||
task_payload: db_payload.TaskPayload = task.payload.payload
|
||||
if isinstance(task_payload, db_payload.InitialPromptPayload):
|
||||
role = "prompter"
|
||||
elif isinstance(task_payload, db_payload.PrompterReplyPayload):
|
||||
role = "prompter"
|
||||
elif isinstance(task_payload, db_payload.AssistantReplyPayload):
|
||||
role = "assistant"
|
||||
elif isinstance(task_payload, db_payload.SummarizationStoryPayload):
|
||||
raise NotImplementedError("SummarizationStory task not implemented.")
|
||||
else:
|
||||
raise OasstError(
|
||||
f"Unexpected task payload type: {type(task_payload).__name__}",
|
||||
OasstErrorCode.TASK_UNEXPECTED_PAYLOAD_TYPE_,
|
||||
)
|
||||
|
||||
assert role in ("assistant", "prompter")
|
||||
|
||||
# create reply message
|
||||
new_message_id = uuid4()
|
||||
@@ -192,10 +208,10 @@ class PromptRepository:
|
||||
if not task.collective:
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
self.journal.log_text_reply(task=task, message_id=new_message_id, role=role, length=len(text))
|
||||
return user_message
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction:
|
||||
message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True)
|
||||
|
||||
@@ -225,6 +241,7 @@ class PromptRepository:
|
||||
logger.info(f"Ranking {rating.rating} stored for task {task.id}.")
|
||||
return reaction
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> Tuple[MessageReaction, Task]:
|
||||
# fetch task
|
||||
task = self.task_repository.fetch_task_by_frontend_message_id(ranking.message_id)
|
||||
@@ -297,6 +314,7 @@ class PromptRepository:
|
||||
|
||||
return reaction, task
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def insert_toxicity(self, message_id: UUID, model: str, score: float, label: str) -> MessageToxicity:
|
||||
"""Save the toxicity score of a new message in the database.
|
||||
Args:
|
||||
@@ -312,10 +330,9 @@ class PromptRepository:
|
||||
|
||||
message_toxicity = MessageToxicity(message_id=message_id, model=model, score=score, label=label)
|
||||
self.db.add(message_toxicity)
|
||||
self.db.commit()
|
||||
self.db.refresh(message_toxicity)
|
||||
return message_toxicity
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding:
|
||||
"""Insert the embedding of a new message in the database.
|
||||
|
||||
@@ -333,10 +350,9 @@ class PromptRepository:
|
||||
|
||||
message_embedding = MessageEmbedding(message_id=message_id, model=model, embedding=embedding)
|
||||
self.db.add(message_embedding)
|
||||
self.db.commit()
|
||||
self.db.refresh(message_embedding)
|
||||
return message_embedding
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
|
||||
if self.user_id is None:
|
||||
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
|
||||
@@ -350,10 +366,9 @@ class PromptRepository:
|
||||
payload_type=type(payload).__name__,
|
||||
)
|
||||
self.db.add(reaction)
|
||||
self.db.commit()
|
||||
self.db.refresh(reaction)
|
||||
return reaction
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> Tuple[TextLabels, Task, Message]:
|
||||
|
||||
valid_labels: Optional[list[str]] = None
|
||||
@@ -423,8 +438,6 @@ class PromptRepository:
|
||||
self.db.add(message)
|
||||
|
||||
self.db.add(model)
|
||||
self.db.commit()
|
||||
self.db.refresh(model)
|
||||
return model, task, message
|
||||
|
||||
def fetch_random_message_tree(self, require_role: str = None, reviewed: bool = True) -> list[Message]:
|
||||
@@ -689,6 +702,7 @@ class PromptRepository:
|
||||
|
||||
return messages.all()
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True):
|
||||
"""
|
||||
Marks deleted messages and all their descendants.
|
||||
@@ -717,8 +731,6 @@ class PromptRepository:
|
||||
|
||||
parent_ids = self.db.execute(query).scalars().all()
|
||||
|
||||
self.db.commit()
|
||||
|
||||
def get_stats(self) -> SystemStats:
|
||||
"""
|
||||
Get data stats such as number of all messages in the system,
|
||||
|
||||
@@ -6,6 +6,7 @@ from loguru import logger
|
||||
from oasst_backend.models import ApiClient, Task
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
@@ -128,6 +129,7 @@ class TaskRepository:
|
||||
assert task_model.id == task.id
|
||||
return task_model
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
|
||||
@@ -142,10 +144,9 @@ class TaskRepository:
|
||||
|
||||
task.frontend_message_id = frontend_message_id
|
||||
task.ack = True
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
|
||||
"""
|
||||
Mark task as done. No further messages will be accepted for this task.
|
||||
@@ -166,8 +167,8 @@ class TaskRepository:
|
||||
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def acknowledge_task_failure(self, task_id):
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
@@ -181,8 +182,8 @@ class TaskRepository:
|
||||
task.ack = False
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def insert_task(
|
||||
self,
|
||||
payload: db_payload.TaskPayload,
|
||||
@@ -204,8 +205,6 @@ class TaskRepository:
|
||||
)
|
||||
logger.debug(f"inserting {task=}")
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
self.db.refresh(task)
|
||||
return task
|
||||
|
||||
def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
|
||||
|
||||
@@ -9,13 +9,14 @@ import pydantic
|
||||
from loguru import logger
|
||||
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
|
||||
from oasst_backend.config import TreeManagerConfiguration, settings
|
||||
from oasst_backend.models import Message, MessageReaction, MessageTreeState, TextLabels, message_tree_state
|
||||
from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, message_tree_state
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
|
||||
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlalchemy.sql import text
|
||||
from sqlmodel import Session, func
|
||||
from sqlmodel import Session, func, not_
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
@@ -48,6 +49,7 @@ class ActiveTreeSizeRow(pydantic.BaseModel):
|
||||
|
||||
class ExtendibleParentRow(pydantic.BaseModel):
|
||||
parent_id: UUID
|
||||
parent_role: str
|
||||
depth: int
|
||||
message_tree_id: UUID
|
||||
active_children_count: int
|
||||
@@ -58,6 +60,7 @@ class ExtendibleParentRow(pydantic.BaseModel):
|
||||
|
||||
class IncompleteRankingsRow(pydantic.BaseModel):
|
||||
parent_id: UUID
|
||||
role: str
|
||||
children_count: int
|
||||
child_min_ranking_count: int
|
||||
|
||||
@@ -69,21 +72,23 @@ class TreeManager:
|
||||
_all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel))
|
||||
|
||||
def __init__(
|
||||
self, db: Session, prompt_repository: PromptRepository, cfg: Optional[TreeManagerConfiguration] = None
|
||||
self,
|
||||
db: Session,
|
||||
prompt_repository: PromptRepository,
|
||||
cfg: Optional[TreeManagerConfiguration] = None,
|
||||
):
|
||||
self.db = db
|
||||
self.cfg = cfg or settings.tree_manager
|
||||
self.pr = prompt_repository
|
||||
|
||||
def _task_selection(
|
||||
def _random_task_selection(
|
||||
self,
|
||||
desired_task_type: protocol_schema.TaskRequestType,
|
||||
num_ranking_tasks: int,
|
||||
num_replies_need_review: int,
|
||||
num_prompts_need_review: int,
|
||||
num_missing_prompts: int,
|
||||
num_missing_replies: int,
|
||||
) -> Tuple[TaskType, TaskRole]:
|
||||
) -> TaskType:
|
||||
"""
|
||||
Determines which task to hand out to human worker.
|
||||
The task type is drawn with relative weight (e.g. ranking has highest priority)
|
||||
@@ -91,75 +96,97 @@ class TreeManager:
|
||||
"""
|
||||
|
||||
logger.debug(
|
||||
f"TreeManager._task_selection({num_ranking_tasks=}, {num_replies_need_review=}, "
|
||||
f"TreeManager._random_task_selection({num_ranking_tasks=}, {num_replies_need_review=}, "
|
||||
f"{num_prompts_need_review=}, {num_missing_prompts=}, {num_missing_replies=})"
|
||||
)
|
||||
|
||||
task_type = TaskType.NONE
|
||||
task_role = TaskRole.ANY
|
||||
if desired_task_type == protocol_schema.TaskRequestType.random:
|
||||
task_weights = [0] * 5
|
||||
task_weights = [0] * 5
|
||||
|
||||
if num_ranking_tasks > 0:
|
||||
task_weights[TaskType.RANKING.value] = 10
|
||||
if num_ranking_tasks > 0:
|
||||
task_weights[TaskType.RANKING.value] = 10
|
||||
|
||||
if num_replies_need_review > 0:
|
||||
task_weights[TaskType.LABEL_REPLY.value] = 5
|
||||
if num_replies_need_review > 0:
|
||||
task_weights[TaskType.LABEL_REPLY.value] = 5
|
||||
|
||||
if num_prompts_need_review > 0:
|
||||
task_weights[TaskType.LABEL_PROMPT.value] = 5
|
||||
if num_prompts_need_review > 0:
|
||||
task_weights[TaskType.LABEL_PROMPT.value] = 5
|
||||
|
||||
if num_missing_replies > 0:
|
||||
task_weights[TaskType.REPLY.value] = 2
|
||||
if num_missing_replies > 0:
|
||||
task_weights[TaskType.REPLY.value] = 2
|
||||
|
||||
if num_missing_prompts > 0:
|
||||
task_weights[TaskType.PROMPT.value] = 1
|
||||
if num_missing_prompts > 0:
|
||||
task_weights[TaskType.PROMPT.value] = 1
|
||||
|
||||
task_weights = np.array(task_weights)
|
||||
weight_sum = task_weights.sum()
|
||||
if weight_sum < 1e-8:
|
||||
task_type = TaskType.NONE
|
||||
else:
|
||||
task_weights = task_weights / weight_sum
|
||||
task_type = TaskType(np.random.choice(a=len(task_weights), p=task_weights))
|
||||
else:
|
||||
match desired_task_type:
|
||||
case protocol_schema.TaskRequestType.initial_prompt:
|
||||
if num_missing_prompts > 0:
|
||||
task_type = TaskType.PROMPT
|
||||
case protocol_schema.TaskRequestType.label_initial_prompt:
|
||||
if num_prompts_need_review > 0:
|
||||
task_type = TaskType.LABEL_PROMPT
|
||||
case protocol_schema.TaskRequestType.assistant_reply | protocol_schema.TaskRequestType.prompter_reply:
|
||||
if num_missing_replies > 0:
|
||||
task_role = (
|
||||
TaskRole.ASSISTANT
|
||||
if desired_task_type == protocol_schema.TaskRequestType.assistant_reply
|
||||
else TaskRole.PROMPTER
|
||||
)
|
||||
task_type = TaskType.REPLY
|
||||
case protocol_schema.TaskRequestType.label_assistant_reply | protocol_schema.TaskRequestType.label_prompter_reply:
|
||||
if num_replies_need_review > 0:
|
||||
task_role = (
|
||||
TaskRole.ASSISTANT
|
||||
if desired_task_type == protocol_schema.TaskRequestType.label_assistant_reply
|
||||
else TaskRole.PROMPTER
|
||||
)
|
||||
task_type = TaskType.LABEL_REPLY
|
||||
case protocol_schema.TaskRequestType.rank_assistant_replies | protocol_schema.TaskRequestType.rank_prompter_replies:
|
||||
if num_ranking_tasks > 0:
|
||||
task_role = (
|
||||
TaskRole.ASSISTANT
|
||||
if desired_task_type == protocol_schema.TaskRequestType.rank_assistant_replies
|
||||
else TaskRole.PROMPTER
|
||||
)
|
||||
task_type = TaskType.RANKING
|
||||
task_weights = np.array(task_weights)
|
||||
weight_sum = task_weights.sum()
|
||||
if weight_sum > 1e-8:
|
||||
task_weights = task_weights / weight_sum
|
||||
task_type = TaskType(np.random.choice(a=len(task_weights), p=task_weights))
|
||||
|
||||
logger.debug(f"Selected {task_type=}, {task_role=}")
|
||||
return task_type, task_role
|
||||
logger.debug(f"Selected {task_type=}")
|
||||
return task_type
|
||||
|
||||
def _determine_task_availability_internal(
|
||||
self,
|
||||
num_active_trees: int,
|
||||
extensible_parents: list[ExtendibleParentRow],
|
||||
prompts_need_review: list[Message],
|
||||
replies_need_review: list[Message],
|
||||
incomplete_rankings: list[IncompleteRankingsRow],
|
||||
) -> dict[protocol_schema.TaskRequestType, int]:
|
||||
task_count_by_type: dict[protocol_schema.TaskRequestType, int] = {t: 0 for t in protocol_schema.TaskRequestType}
|
||||
|
||||
num_missing_prompts = max(0, self.cfg.max_active_trees - num_active_trees)
|
||||
task_count_by_type[protocol_schema.TaskRequestType.initial_prompt] = num_missing_prompts
|
||||
|
||||
task_count_by_type[protocol_schema.TaskRequestType.prompter_reply] = len(
|
||||
list(filter(lambda x: x.parent_role == "assistant", extensible_parents))
|
||||
)
|
||||
task_count_by_type[protocol_schema.TaskRequestType.assistant_reply] = len(
|
||||
list(filter(lambda x: x.parent_role == "prompter", extensible_parents))
|
||||
)
|
||||
|
||||
task_count_by_type[protocol_schema.TaskRequestType.label_initial_prompt] = len(prompts_need_review)
|
||||
task_count_by_type[protocol_schema.TaskRequestType.label_assistant_reply] = len(
|
||||
list(filter(lambda m: m.role == "assistant", replies_need_review))
|
||||
)
|
||||
task_count_by_type[protocol_schema.TaskRequestType.prompter_reply] = len(
|
||||
list(filter(lambda m: m.role == "prompter", replies_need_review))
|
||||
)
|
||||
|
||||
if self.cfg.rank_prompter_replies:
|
||||
task_count_by_type[protocol_schema.TaskRequestType.rank_prompter_replies] = len(
|
||||
list(filter(lambda r: r.role == "prompter", incomplete_rankings))
|
||||
)
|
||||
|
||||
task_count_by_type[protocol_schema.TaskRequestType.rank_assistant_replies] = len(
|
||||
list(filter(lambda r: r.role == "assistant", incomplete_rankings))
|
||||
)
|
||||
|
||||
task_count_by_type[protocol_schema.TaskRequestType.random] = sum(
|
||||
task_count_by_type[t] for t in protocol_schema.TaskRequestType if t in task_count_by_type
|
||||
)
|
||||
|
||||
return task_count_by_type
|
||||
|
||||
def determine_task_availability(self) -> dict[protocol_schema.TaskRequestType, int]:
|
||||
num_active_trees = self.query_num_active_trees()
|
||||
extensible_parents = self.query_extendible_parents()
|
||||
prompts_need_review = self.query_prompts_need_review()
|
||||
replies_need_review = self.query_replies_need_review()
|
||||
incomplete_rankings = self.query_incomplete_rankings()
|
||||
|
||||
return self._determine_task_availability_internal(
|
||||
num_active_trees=num_active_trees,
|
||||
extensible_parents=extensible_parents,
|
||||
prompts_need_review=prompts_need_review,
|
||||
replies_need_review=replies_need_review,
|
||||
incomplete_rankings=incomplete_rankings,
|
||||
)
|
||||
|
||||
def next_task(
|
||||
self, desired_task_type: protocol_schema.TaskRequestType
|
||||
self, desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random
|
||||
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
|
||||
|
||||
logger.debug("TreeManager.next_task()")
|
||||
@@ -167,148 +194,195 @@ class TreeManager:
|
||||
num_active_trees = self.query_num_active_trees()
|
||||
prompts_need_review = self.query_prompts_need_review()
|
||||
replies_need_review = self.query_replies_need_review()
|
||||
extensible_parents = self.query_extendible_parents()
|
||||
|
||||
incomplete_rankings = self.query_incomplete_rankings()
|
||||
if not self.cfg.rank_prompter_replies:
|
||||
incomplete_rankings = list(filter(lambda r: r.role == "assistant", incomplete_rankings))
|
||||
|
||||
active_tree_sizes = self.query_extendible_trees()
|
||||
|
||||
# determine type of task to generate
|
||||
num_missing_replies = sum(x.remaining_messages for x in active_tree_sizes)
|
||||
|
||||
task_type, task_role = self._task_selection(
|
||||
desired_task_type,
|
||||
num_ranking_tasks=len(incomplete_rankings),
|
||||
num_replies_need_review=len(replies_need_review),
|
||||
num_prompts_need_review=len(prompts_need_review),
|
||||
num_missing_prompts=max(0, self.cfg.max_active_trees - num_active_trees),
|
||||
num_missing_replies=num_missing_replies,
|
||||
)
|
||||
|
||||
if task_type == TaskType.NONE:
|
||||
raise OasstError(
|
||||
f"No tasks of type '{desired_task_type.value}' are currently available.",
|
||||
OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE,
|
||||
HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
task_role = TaskRole.ANY
|
||||
if desired_task_type == protocol_schema.TaskRequestType.random:
|
||||
task_type = self._random_task_selection(
|
||||
num_ranking_tasks=len(incomplete_rankings),
|
||||
num_replies_need_review=len(replies_need_review),
|
||||
num_prompts_need_review=len(prompts_need_review),
|
||||
num_missing_prompts=max(0, self.cfg.max_active_trees - num_active_trees),
|
||||
num_missing_replies=num_missing_replies,
|
||||
)
|
||||
|
||||
if task_role != TaskRole.ANY:
|
||||
# Todo: Allow role specific message selection...
|
||||
raise OasstError(
|
||||
f"No tasks of type '{desired_task_type.value}' are currently available.",
|
||||
OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE,
|
||||
HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
if task_type == TaskType.NONE:
|
||||
raise OasstError(
|
||||
f"No tasks of type '{protocol_schema.TaskRequestType.random.value}' are currently available.",
|
||||
OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE,
|
||||
HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
else:
|
||||
task_count_by_type = self._determine_task_availability_internal(
|
||||
num_active_trees=num_active_trees,
|
||||
extensible_parents=extensible_parents,
|
||||
prompts_need_review=prompts_need_review,
|
||||
replies_need_review=replies_need_review,
|
||||
incomplete_rankings=incomplete_rankings,
|
||||
)
|
||||
|
||||
available_count = task_count_by_type.get(desired_task_type)
|
||||
if not available_count:
|
||||
raise OasstError(
|
||||
f"No tasks of type '{desired_task_type.value}' are currently available.",
|
||||
OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE,
|
||||
HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
|
||||
task_type_role_map = {
|
||||
protocol_schema.TaskRequestType.initial_prompt: (TaskType.PROMPT, TaskRole.ANY),
|
||||
protocol_schema.TaskRequestType.prompter_reply: (TaskType.REPLY, TaskRole.PROMPTER),
|
||||
protocol_schema.TaskRequestType.assistant_reply: (TaskType.REPLY, TaskRole.ASSISTANT),
|
||||
protocol_schema.TaskRequestType.rank_prompter_replies: (TaskType.RANKING, TaskRole.PROMPTER),
|
||||
protocol_schema.TaskRequestType.rank_assistant_replies: (TaskType.RANKING, TaskRole.ASSISTANT),
|
||||
protocol_schema.TaskRequestType.label_initial_prompt: (TaskType.LABEL_PROMPT, TaskRole.ANY),
|
||||
protocol_schema.TaskRequestType.label_assistant_reply: (TaskType.LABEL_REPLY, TaskRole.ASSISTANT),
|
||||
protocol_schema.TaskRequestType.label_prompter_reply: (TaskType.LABEL_REPLY, TaskRole.PROMPTER),
|
||||
}
|
||||
|
||||
task_type, task_role = task_type_role_map[desired_task_type]
|
||||
|
||||
message_tree_id = None
|
||||
parent_message_id = None
|
||||
|
||||
logger.debug(f"selected {task_type=}")
|
||||
match task_type:
|
||||
case TaskType.RANKING:
|
||||
assert len(incomplete_rankings) > 0
|
||||
ranking_parent_id = random.choice(incomplete_rankings).parent_id
|
||||
if task_role == TaskRole.PROMPTER:
|
||||
incomplete_rankings = list(filter(lambda m: m.role == "prompter", incomplete_rankings))
|
||||
elif task_role == TaskRole.ASSISTANT:
|
||||
incomplete_rankings = list(filter(lambda m: m.role == "assistant", incomplete_rankings))
|
||||
|
||||
messages = self.pr.fetch_message_conversation(ranking_parent_id)
|
||||
assert len(messages) > 1 and messages[-1].id == ranking_parent_id
|
||||
ranking_parent = messages[-1]
|
||||
assert not ranking_parent.deleted and ranking_parent.review_result
|
||||
conversation = prepare_conversation(messages)
|
||||
replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True)
|
||||
if len(incomplete_rankings) > 0:
|
||||
ranking_parent_id = random.choice(incomplete_rankings).parent_id
|
||||
|
||||
assert len(replies) > 1
|
||||
random.shuffle(replies) # hand out replies in random order
|
||||
reply_messages = prepare_conversation_message_list(replies)
|
||||
replies = [p.text for p in replies]
|
||||
messages = self.pr.fetch_message_conversation(ranking_parent_id)
|
||||
assert len(messages) > 1 and messages[-1].id == ranking_parent_id
|
||||
ranking_parent = messages[-1]
|
||||
assert not ranking_parent.deleted and ranking_parent.review_result
|
||||
conversation = prepare_conversation(messages)
|
||||
replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True)
|
||||
|
||||
if messages[-1].role == "assistant":
|
||||
logger.info("Generating a RankPrompterRepliesTask.")
|
||||
task = protocol_schema.RankPrompterRepliesTask(
|
||||
conversation=conversation,
|
||||
replies=replies,
|
||||
reply_messages=reply_messages,
|
||||
ranking_parent_id=ranking_parent.id,
|
||||
message_tree_id=ranking_parent.message_tree_id,
|
||||
)
|
||||
else:
|
||||
logger.info("Generating a RankAssistantRepliesTask.")
|
||||
task = protocol_schema.RankAssistantRepliesTask(
|
||||
conversation=conversation,
|
||||
replies=replies,
|
||||
reply_messages=reply_messages,
|
||||
ranking_parent_id=ranking_parent.id,
|
||||
message_tree_id=ranking_parent.message_tree_id,
|
||||
)
|
||||
assert len(replies) > 1
|
||||
random.shuffle(replies) # hand out replies in random order
|
||||
reply_messages = prepare_conversation_message_list(replies)
|
||||
replies = [p.text for p in replies]
|
||||
|
||||
parent_message_id = ranking_parent_id
|
||||
message_tree_id = messages[-1].message_tree_id
|
||||
if messages[-1].role == "assistant":
|
||||
logger.info("Generating a RankPrompterRepliesTask.")
|
||||
task = protocol_schema.RankPrompterRepliesTask(
|
||||
conversation=conversation,
|
||||
replies=replies,
|
||||
reply_messages=reply_messages,
|
||||
ranking_parent_id=ranking_parent.id,
|
||||
message_tree_id=ranking_parent.message_tree_id,
|
||||
)
|
||||
else:
|
||||
logger.info("Generating a RankAssistantRepliesTask.")
|
||||
task = protocol_schema.RankAssistantRepliesTask(
|
||||
conversation=conversation,
|
||||
replies=replies,
|
||||
reply_messages=reply_messages,
|
||||
ranking_parent_id=ranking_parent.id,
|
||||
message_tree_id=ranking_parent.message_tree_id,
|
||||
)
|
||||
|
||||
parent_message_id = ranking_parent_id
|
||||
message_tree_id = messages[-1].message_tree_id
|
||||
|
||||
case TaskType.LABEL_REPLY:
|
||||
assert len(replies_need_review) > 0
|
||||
random_reply_message_id = random.choice(replies_need_review)
|
||||
messages = self.pr.fetch_message_conversation(random_reply_message_id)
|
||||
if task_role == TaskRole.PROMPTER:
|
||||
replies_need_review = list(filter(lambda m: m.role == "prompter", replies_need_review))
|
||||
elif task_role == TaskRole.ASSISTANT:
|
||||
replies_need_review = list(filter(lambda m: m.role == "assistant", replies_need_review))
|
||||
|
||||
conversation = prepare_conversation(messages[:-1])
|
||||
message = messages[-1]
|
||||
if len(replies_need_review) > 0:
|
||||
random_reply_message = random.choice(replies_need_review)
|
||||
messages = self.pr.fetch_message_conversation(random_reply_message)
|
||||
|
||||
self.cfg.p_full_labeling_review_reply_prompter: float = 0.1
|
||||
conversation = prepare_conversation(messages[:-1])
|
||||
message = messages[-1]
|
||||
|
||||
label_mode = protocol_schema.LabelTaskMode.full
|
||||
valid_labels = self._all_text_labels
|
||||
self.cfg.p_full_labeling_review_reply_prompter: float = 0.1
|
||||
|
||||
if message.role == "assistant":
|
||||
if random.random() > self.cfg.p_full_labeling_review_reply_assistant:
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply))
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})")
|
||||
task = protocol_schema.LabelAssistantReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message.text,
|
||||
valid_labels=valid_labels,
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
|
||||
mode=label_mode,
|
||||
)
|
||||
else:
|
||||
if random.random() > self.cfg.p_full_labeling_review_reply_prompter:
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply))
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})")
|
||||
task = protocol_schema.LabelPrompterReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message.text,
|
||||
valid_labels=valid_labels,
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
|
||||
mode=label_mode,
|
||||
)
|
||||
label_mode = protocol_schema.LabelTaskMode.full
|
||||
valid_labels = self._all_text_labels
|
||||
|
||||
parent_message_id = message.id
|
||||
message_tree_id = message.message_tree_id
|
||||
if message.role == "assistant":
|
||||
if (
|
||||
desired_task_type == protocol_schema.TaskRequestType.random
|
||||
and random.random() > self.cfg.p_full_labeling_review_reply_assistant
|
||||
):
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply))
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})")
|
||||
task = protocol_schema.LabelAssistantReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message.text,
|
||||
valid_labels=valid_labels,
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)),
|
||||
mode=label_mode,
|
||||
)
|
||||
else:
|
||||
if (
|
||||
desired_task_type == protocol_schema.TaskRequestType.random
|
||||
and random.random() > self.cfg.p_full_labeling_review_reply_prompter
|
||||
):
|
||||
valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply))
|
||||
label_mode = protocol_schema.LabelTaskMode.simple
|
||||
logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})")
|
||||
task = protocol_schema.LabelPrompterReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
reply=message.text,
|
||||
valid_labels=valid_labels,
|
||||
mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)),
|
||||
mode=label_mode,
|
||||
)
|
||||
|
||||
parent_message_id = message.id
|
||||
message_tree_id = message.message_tree_id
|
||||
|
||||
case TaskType.REPLY:
|
||||
# select a tree with missing replies
|
||||
extensible_parents = self.query_extendible_parents()
|
||||
assert len(extensible_parents) > 0
|
||||
if task_role == TaskRole.PROMPTER:
|
||||
extensible_parents = list(filter(lambda x: x.parent_role == "assistant", extensible_parents))
|
||||
elif task_role == TaskRole.ASSISTANT:
|
||||
extensible_parents = list(filter(lambda x: x.parent_role == "prompter", extensible_parents))
|
||||
|
||||
# fetch random conversation to extend
|
||||
random_parent = random.choice(extensible_parents)
|
||||
logger.debug(f"selected {random_parent=}")
|
||||
messages = self.pr.fetch_message_conversation(random_parent.parent_id)
|
||||
assert all(m.review_result for m in messages) # ensure all messages have positive review
|
||||
conversation = prepare_conversation(messages)
|
||||
if len(extensible_parents) > 0:
|
||||
random_parent = random.choice(extensible_parents)
|
||||
|
||||
# generate reply task depending on last message
|
||||
if messages[-1].role == "assistant":
|
||||
logger.info("Generating a PrompterReplyTask.")
|
||||
task = protocol_schema.PrompterReplyTask(conversation=conversation)
|
||||
else:
|
||||
logger.info("Generating a AssistantReplyTask.")
|
||||
task = protocol_schema.AssistantReplyTask(conversation=conversation)
|
||||
# fetch random conversation to extend
|
||||
logger.debug(f"selected {random_parent=}")
|
||||
messages = self.pr.fetch_message_conversation(random_parent.parent_id)
|
||||
assert all(m.review_result for m in messages) # ensure all messages have positive review
|
||||
conversation = prepare_conversation(messages)
|
||||
|
||||
parent_message_id = messages[-1].id
|
||||
message_tree_id = messages[-1].message_tree_id
|
||||
# generate reply task depending on last message
|
||||
if messages[-1].role == "assistant":
|
||||
logger.info("Generating a PrompterReplyTask.")
|
||||
task = protocol_schema.PrompterReplyTask(conversation=conversation)
|
||||
else:
|
||||
logger.info("Generating a AssistantReplyTask.")
|
||||
task = protocol_schema.AssistantReplyTask(conversation=conversation)
|
||||
|
||||
parent_message_id = messages[-1].id
|
||||
message_tree_id = messages[-1].message_tree_id
|
||||
|
||||
case TaskType.LABEL_PROMPT:
|
||||
assert len(prompts_need_review) > 0
|
||||
message = self.pr.fetch_message(random.choice(prompts_need_review))
|
||||
message = random.choice(prompts_need_review)
|
||||
|
||||
label_mode = protocol_schema.LabelTaskMode.full
|
||||
valid_labels = self._all_text_labels
|
||||
@@ -336,10 +410,18 @@ class TreeManager:
|
||||
case _:
|
||||
task = None
|
||||
|
||||
if task is None:
|
||||
raise OasstError(
|
||||
f"No task of type '{desired_task_type.value}' is currently available.",
|
||||
OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE,
|
||||
HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
|
||||
logger.info(f"Generated {task=}.")
|
||||
|
||||
return task, message_tree_id, parent_message_id
|
||||
|
||||
@async_managed_tx_method(CommitMode.COMMIT)
|
||||
async def handle_interaction(self, interaction: protocol_schema.AnyInteraction) -> protocol_schema.Task:
|
||||
pr = self.pr
|
||||
match type(interaction):
|
||||
@@ -358,7 +440,6 @@ class TreeManager:
|
||||
if not message.parent_id:
|
||||
logger.info(f"TreeManager: Inserting new tree state for initial prompt {message.id=}")
|
||||
self._insert_default_state(message.id)
|
||||
self.db.commit()
|
||||
|
||||
if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION:
|
||||
try:
|
||||
@@ -428,7 +509,6 @@ class TreeManager:
|
||||
if acceptance_score > self.cfg.acceptance_threshold_initial_prompt:
|
||||
msg.review_result = True
|
||||
self.db.add(msg)
|
||||
self.db.commit()
|
||||
logger.info(
|
||||
f"Initial prompt message was accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}"
|
||||
)
|
||||
@@ -439,7 +519,6 @@ class TreeManager:
|
||||
if not msg.review_result and acceptance_score > self.cfg.acceptance_threshold_reply:
|
||||
msg.review_result = True
|
||||
self.db.add(msg)
|
||||
self.db.commit()
|
||||
logger.info(
|
||||
f"Reply message message accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}"
|
||||
)
|
||||
@@ -451,6 +530,7 @@ class TreeManager:
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def _enter_state(self, mts: MessageTreeState, state: message_tree_state.State):
|
||||
assert mts and mts.active
|
||||
|
||||
@@ -460,7 +540,6 @@ class TreeManager:
|
||||
mts.active = False
|
||||
mts.state = state.value
|
||||
self.db.add(mts)
|
||||
self.db.commit()
|
||||
|
||||
if is_terminal:
|
||||
logger.info(f"Tree entered terminal '{mts.state}' state ({mts.message_tree_id=})")
|
||||
@@ -472,6 +551,7 @@ class TreeManager:
|
||||
mts = self.pr.fetch_tree_state(message_tree_id)
|
||||
self._enter_state(mts, message_tree_state.State.ABORTED_LOW_GRADE)
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def check_condition_for_growing_state(self, message_tree_id: UUID) -> bool:
|
||||
logger.debug(f"check_condition_for_growing_state({message_tree_id=})")
|
||||
|
||||
@@ -489,6 +569,7 @@ class TreeManager:
|
||||
self._enter_state(mts, message_tree_state.State.GROWING)
|
||||
return True
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def check_condition_for_ranking_state(self, message_tree_id: UUID) -> bool:
|
||||
logger.debug(f"check_condition_for_ranking_state({message_tree_id=})")
|
||||
|
||||
@@ -514,7 +595,8 @@ class TreeManager:
|
||||
logger.debug(f"False {mts.active=}, {mts.state=}")
|
||||
return False
|
||||
|
||||
rankings_by_message = self.query_tree_ranking_results(message_tree_id)
|
||||
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
|
||||
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
|
||||
for parent_msg_id, ranking in rankings_by_message.items():
|
||||
if len(ranking) < self.cfg.num_required_rankings:
|
||||
logger.debug(f"False {parent_msg_id=} {len(ranking)=}")
|
||||
@@ -527,68 +609,59 @@ class TreeManager:
|
||||
# calculate acceptance based on spam label
|
||||
return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels])
|
||||
|
||||
_sql_find_prompts_need_review = """
|
||||
-- find initial prompts that need more reviews
|
||||
SELECT m.id
|
||||
FROM message_tree_state mts
|
||||
LEFT JOIN message m ON mts.message_tree_id = m.id
|
||||
WHERE mts.active
|
||||
AND mts.state = :state
|
||||
AND NOT m.review_result
|
||||
AND NOT m.deleted
|
||||
AND m.review_count < :num_reviews_initial_prompt
|
||||
AND m.parent_id is NULL
|
||||
AND (:excluded_user_id IS NULL OR m.user_id != :excluded_user_id)
|
||||
"""
|
||||
|
||||
def query_prompts_need_review(self) -> list[UUID]:
|
||||
def query_prompts_need_review(self) -> list[Message]:
|
||||
"""
|
||||
Select id of initial prompts with less then required rankings in active message tree
|
||||
Select initial prompt messages with less then required rankings in active message tree
|
||||
(active == True in message_tree_state)
|
||||
"""
|
||||
|
||||
r = self.db.execute(
|
||||
text(self._sql_find_prompts_need_review),
|
||||
{
|
||||
"state": message_tree_state.State.INITIAL_PROMPT_REVIEW,
|
||||
"num_reviews_initial_prompt": self.cfg.num_reviews_initial_prompt,
|
||||
"excluded_user_id": None if settings.DEBUG_ALLOW_SELF_LABELING else self.pr.user_id,
|
||||
},
|
||||
qry = (
|
||||
self.db.query(Message)
|
||||
.select_from(MessageTreeState)
|
||||
.outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
|
||||
.filter(
|
||||
MessageTreeState.active,
|
||||
MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW,
|
||||
not_(Message.review_result),
|
||||
not_(Message.deleted),
|
||||
Message.review_count < self.cfg.num_reviews_initial_prompt,
|
||||
Message.parent_id.is_(None),
|
||||
)
|
||||
)
|
||||
return [x["id"] for x in r.all()]
|
||||
|
||||
_sql_find_replies_need_review = """
|
||||
SELECT m.id
|
||||
FROM message_tree_state mts
|
||||
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
|
||||
WHERE mts.active
|
||||
AND mts.state = :breeding_state
|
||||
AND NOT m.review_result
|
||||
AND NOT m.deleted
|
||||
AND m.review_count < :num_required_reviews
|
||||
AND m.parent_id is NOT NULL
|
||||
AND (:excluded_user_id IS NULL OR m.user_id != :excluded_user_id)
|
||||
"""
|
||||
if not settings.DEBUG_ALLOW_SELF_LABELING:
|
||||
qry = qry.filter(Message.user_id != self.pr.user_id)
|
||||
|
||||
def query_replies_need_review(self) -> list[UUID]:
|
||||
return qry.all()
|
||||
|
||||
def query_replies_need_review(self) -> list[Message]:
|
||||
"""
|
||||
Select ids of child messages (parent_id IS NOT NULL) with less then required rankings
|
||||
Select child messages (parent_id IS NOT NULL) with less then required rankings
|
||||
in active message tree (active == True in message_tree_state)
|
||||
"""
|
||||
|
||||
r = self.db.execute(
|
||||
text(self._sql_find_replies_need_review),
|
||||
{
|
||||
"breeding_state": message_tree_state.State.GROWING,
|
||||
"num_required_reviews": self.cfg.num_reviews_reply,
|
||||
"excluded_user_id": None if settings.DEBUG_ALLOW_SELF_LABELING else self.pr.user_id,
|
||||
},
|
||||
qry = (
|
||||
self.db.query(Message)
|
||||
.select_from(MessageTreeState)
|
||||
.outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
|
||||
.filter(
|
||||
MessageTreeState.active,
|
||||
MessageTreeState.state == message_tree_state.State.GROWING,
|
||||
not_(Message.review_result),
|
||||
not_(Message.deleted),
|
||||
Message.review_count < self.cfg.num_reviews_reply,
|
||||
Message.parent_id.is_not(None),
|
||||
)
|
||||
)
|
||||
return [x["id"] for x in r.all()]
|
||||
|
||||
if not settings.DEBUG_ALLOW_SELF_LABELING:
|
||||
qry = qry.filter(Message.user_id != self.pr.user_id)
|
||||
|
||||
return qry.all()
|
||||
|
||||
_sql_find_incomplete_rankings = """
|
||||
-- find incomplete rankings
|
||||
SELECT m.parent_id, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count,
|
||||
SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count,
|
||||
COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings
|
||||
FROM message_tree_state mts
|
||||
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
|
||||
@@ -597,7 +670,7 @@ WHERE mts.active -- only consider active trees
|
||||
AND m.review_result -- must be reviewed
|
||||
AND NOT m.deleted -- not deleted
|
||||
AND m.parent_id IS NOT NULL -- ignore initial prompts
|
||||
GROUP BY m.parent_id
|
||||
GROUP BY m.parent_id, m.role
|
||||
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
|
||||
"""
|
||||
|
||||
@@ -615,10 +688,10 @@ HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
|
||||
|
||||
_sql_find_extendible_parents = """
|
||||
-- find all extendible parent nodes
|
||||
SELECT m.id as parent_id, m.depth, m.message_tree_id, COUNT(c.id) active_children_count
|
||||
SELECT m.id as parent_id, m.role as parent_role, m.depth, m.message_tree_id, COUNT(c.id) active_children_count
|
||||
FROM message_tree_state mts
|
||||
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
|
||||
LEFT JOIN message c ON m.id = c.Id -- child nodes
|
||||
LEFT JOIN message c ON m.id = c.parent_id -- child nodes
|
||||
WHERE mts.active -- only consider active trees
|
||||
AND mts.state = :growing_state -- message tree must be growing
|
||||
AND NOT m.deleted -- ignore deleted messages as parents
|
||||
@@ -626,7 +699,7 @@ WHERE mts.active -- only consider active trees
|
||||
AND m.review_result -- parent node must have positive review
|
||||
AND NOT c.deleted -- don't count deleted children
|
||||
AND (c.review_result OR c.review_count < :num_reviews_reply) -- don't count children with negative review but count elements under review
|
||||
GROUP BY m.id, m.depth, m.message_tree_id, mts.max_children_count
|
||||
GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count
|
||||
HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
|
||||
"""
|
||||
|
||||
@@ -635,10 +708,7 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
|
||||
|
||||
r = self.db.execute(
|
||||
text(self._sql_find_extendible_parents),
|
||||
{
|
||||
"growing_state": message_tree_state.State.GROWING,
|
||||
"num_reviews_reply": self.cfg.num_reviews_reply,
|
||||
},
|
||||
{"growing_state": message_tree_state.State.GROWING, "num_reviews_reply": self.cfg.num_reviews_reply},
|
||||
)
|
||||
return [ExtendibleParentRow.from_orm(x) for x in r.all()]
|
||||
|
||||
@@ -670,21 +740,27 @@ HAVING COUNT(m.id) < mts.goal_tree_size
|
||||
)
|
||||
return [ActiveTreeSizeRow.from_orm(x) for x in r.all()]
|
||||
|
||||
_sql_get_tree_size = """
|
||||
SELECT mts.message_tree_id, mts.goal_tree_size, COUNT(m.id) AS tree_size
|
||||
FROM message_tree_state mts
|
||||
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
|
||||
WHERE mts.active
|
||||
AND NOT m.deleted
|
||||
AND m.review_result
|
||||
AND mts.message_tree_id = :message_tree_id
|
||||
GROUP BY mts.message_tree_id, mts.goal_tree_size
|
||||
"""
|
||||
|
||||
def query_tree_size(self, message_tree_id: UUID) -> ActiveTreeSizeRow:
|
||||
"""Returns the number of reviewed not deleted messages in the message tree."""
|
||||
r = self.db.execute(text(self._sql_get_tree_size), {"message_tree_id": message_tree_id})
|
||||
return ActiveTreeSizeRow.from_orm(r.one())
|
||||
|
||||
qry = (
|
||||
self.db.query(
|
||||
MessageTreeState.message_tree_id.label("message_tree_id"),
|
||||
MessageTreeState.goal_tree_size.label("goal_tree_size"),
|
||||
func.count(Message.id).label("tree_size"),
|
||||
)
|
||||
.select_from(MessageTreeState)
|
||||
.outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
|
||||
.filter(
|
||||
MessageTreeState.active,
|
||||
not_(Message.deleted),
|
||||
Message.review_result,
|
||||
MessageTreeState.message_tree_id == message_tree_id,
|
||||
)
|
||||
.group_by(MessageTreeState.message_tree_id, MessageTreeState.goal_tree_size)
|
||||
)
|
||||
|
||||
return ActiveTreeSizeRow.from_orm(qry.one())
|
||||
|
||||
def query_misssing_tree_states(self) -> list[UUID]:
|
||||
"""Find all initial prompt messages that have no associated message tree state"""
|
||||
@@ -701,7 +777,7 @@ GROUP BY mts.message_tree_id, mts.goal_tree_size
|
||||
return [m.id for m in qry_missing_tree_states.all()]
|
||||
|
||||
_sql_find_tree_ranking_results = """
|
||||
-- get all ranking results of completed tasks for all parents with >=2 children
|
||||
-- get all ranking results of completed tasks for all parents with >= 2 children
|
||||
SELECT p.parent_id, mr.* FROM
|
||||
(
|
||||
-- find parents with > 1 children
|
||||
@@ -711,7 +787,8 @@ SELECT p.parent_id, mr.* FROM
|
||||
WHERE m.review_result -- must be reviewed
|
||||
AND NOT m.deleted -- not deleted
|
||||
AND m.parent_id IS NOT NULL -- ignore initial prompts
|
||||
AND mts.message_tree_id = :message_tree_id
|
||||
AND (:role IS NULL OR m.role = :role) -- children with matching role
|
||||
AND mts.message_tree_id = :message_tree_id
|
||||
GROUP BY m.parent_id, m.message_tree_id
|
||||
HAVING COUNT(m.id) > 1
|
||||
) as p
|
||||
@@ -719,11 +796,21 @@ LEFT JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_
|
||||
LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
|
||||
"""
|
||||
|
||||
def query_tree_ranking_results(self, message_tree_id: UUID) -> dict[UUID, list[MessageReaction]]:
|
||||
def query_tree_ranking_results(
|
||||
self,
|
||||
message_tree_id: UUID,
|
||||
role_filter: str = "assistant",
|
||||
) -> dict[UUID, list[MessageReaction]]:
|
||||
"""Finds all completed ranking restuls for a message_tree"""
|
||||
|
||||
assert role_filter in (None, "assistant", "prompter")
|
||||
|
||||
r = self.db.execute(
|
||||
text(self._sql_find_tree_ranking_results),
|
||||
{"message_tree_id": message_tree_id},
|
||||
{
|
||||
"message_tree_id": message_tree_id,
|
||||
"role": role_filter,
|
||||
},
|
||||
)
|
||||
|
||||
rankings_by_message = {}
|
||||
@@ -735,6 +822,7 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
|
||||
rankings_by_message[parent_id].append(MessageReaction.from_orm(x))
|
||||
return rankings_by_message
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def ensure_tree_states(self):
|
||||
"""Add message tree state rows for all root nodes (inital prompt messages)."""
|
||||
|
||||
@@ -746,23 +834,21 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
|
||||
state = message_tree_state.State.GROWING
|
||||
logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=})")
|
||||
self._insert_default_state(id, state=state)
|
||||
self.db.commit()
|
||||
|
||||
def query_num_active_trees(self) -> int:
|
||||
query = self.db.query(func.count(MessageTreeState.message_tree_id)).filter(MessageTreeState.active)
|
||||
return query.scalar()
|
||||
|
||||
def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]:
|
||||
sql_qry = """
|
||||
SELECT tl.*
|
||||
FROM task t
|
||||
INNER JOIN text_labels tl ON tl.id = t.id
|
||||
WHERE t.done = TRUE
|
||||
AND tl.message_id = :message_id
|
||||
"""
|
||||
r = self.db.execute(text(sql_qry), {"message_id": message_id})
|
||||
return [TextLabels.from_orm(x) for x in r.all()]
|
||||
qry = (
|
||||
self.db.query(TextLabels)
|
||||
.select_from(Task)
|
||||
.join(TextLabels, Task.id == TextLabels.id)
|
||||
.filter(Task.done, TextLabels.message_id == message_id)
|
||||
)
|
||||
return qry.all()
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def _insert_tree_state(
|
||||
self,
|
||||
root_message_id: UUID,
|
||||
@@ -784,6 +870,7 @@ WHERE t.done = TRUE
|
||||
self.db.add(model)
|
||||
return model
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def _insert_default_state(
|
||||
self,
|
||||
root_message_id: UUID,
|
||||
@@ -800,12 +887,12 @@ WHERE t.done = TRUE
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from oasst_backend.api.deps import get_dummy_api_client
|
||||
from oasst_backend.api.deps import api_auth
|
||||
from oasst_backend.database import engine
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
|
||||
with Session(engine) as db:
|
||||
api_client = get_dummy_api_client(db)
|
||||
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
|
||||
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
|
||||
pr = PromptRepository(db=db, api_client=api_client, client_user=dummy_user)
|
||||
@@ -814,15 +901,21 @@ if __name__ == "__main__":
|
||||
tm = TreeManager(db, pr, cfg)
|
||||
tm.ensure_tree_states()
|
||||
|
||||
print("query_num_active_trees", tm.query_num_active_trees())
|
||||
print("query_incomplete_rankings", tm.query_incomplete_rankings())
|
||||
print("query_incomplete_reply_reviews", tm.query_replies_need_review())
|
||||
print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
|
||||
print("query_extendible_trees", tm.query_extendible_trees())
|
||||
print("query_extendible_parents", tm.query_extendible_parents())
|
||||
|
||||
print("next_task:", tm.next_task())
|
||||
# print("query_num_active_trees", tm.query_num_active_trees())
|
||||
# print("query_incomplete_rankings", tm.query_incomplete_rankings())
|
||||
# print("query_replies_need_review", tm.query_replies_need_review())
|
||||
# print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review())
|
||||
# print("query_extendible_trees", tm.query_extendible_trees())
|
||||
# print("query_extendible_parents", tm.query_extendible_parents())
|
||||
# print("query_tree_size", tm.query_tree_size(message_tree_id=UUID("bdf434cf-4df5-4b74-949c-a5a157bc3292")))
|
||||
|
||||
print(
|
||||
".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921"))
|
||||
"query_reviews_for_message",
|
||||
tm.query_reviews_for_message(message_id=UUID("6a444493-0d48-4316-a9f1-7e263f5a2473")),
|
||||
)
|
||||
|
||||
# print("next_task:", tm.next_task())
|
||||
|
||||
# print(
|
||||
# "query_tree_ranking_results", tm.query_tree_ranking_results(UUID("6036f58f-41b5-48c4-bdd9-b16f34ab1312"))
|
||||
# )
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.models import ApiClient, User
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
@@ -62,6 +63,7 @@ class UserRepository:
|
||||
|
||||
return user
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def update_user(self, id: UUID, enabled: Optional[bool] = None, notes: Optional[str] = None) -> None:
|
||||
"""
|
||||
Update a user by global user ID to disable or set admin notes. Only trusted clients may update users.
|
||||
@@ -83,8 +85,8 @@ class UserRepository:
|
||||
user.notes = notes
|
||||
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def mark_user_deleted(self, id: UUID) -> None:
|
||||
"""
|
||||
Update a user by global user ID to set deleted flag. Only trusted clients may delete users.
|
||||
@@ -103,8 +105,8 @@ class UserRepository:
|
||||
user.deleted = True
|
||||
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
|
||||
if not client_user:
|
||||
return None
|
||||
@@ -127,20 +129,17 @@ class UserRepository:
|
||||
auth_method=client_user.auth_method,
|
||||
)
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
elif client_user.display_name and client_user.display_name != user.display_name:
|
||||
# we found the user but the display name changed
|
||||
user.display_name = client_user.display_name
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
return user
|
||||
|
||||
def query_users(
|
||||
self,
|
||||
api_client_id: Optional[UUID] = None,
|
||||
limit: Optional[int] = 20,
|
||||
gte: Optional[str] = None,
|
||||
gt: Optional[str] = None,
|
||||
lt: Optional[str] = None,
|
||||
auth_method: Optional[str] = None,
|
||||
) -> list[User]:
|
||||
@@ -161,8 +160,8 @@ class UserRepository:
|
||||
|
||||
users = users.order_by(User.display_name)
|
||||
|
||||
if gte:
|
||||
users = users.filter(User.display_name >= gte)
|
||||
if gt:
|
||||
users = users.filter(User.display_name > gt)
|
||||
|
||||
if lt:
|
||||
users = users.filter(User.display_name < lt)
|
||||
|
||||
@@ -0,0 +1,138 @@
|
||||
from enum import IntEnum
|
||||
from functools import wraps
|
||||
from http import HTTPStatus
|
||||
from typing import Callable
|
||||
|
||||
from loguru import logger
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.database import engine
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlmodel import Session, SQLModel
|
||||
|
||||
|
||||
class CommitMode(IntEnum):
|
||||
"""
|
||||
Commit modes for the managed tx methods
|
||||
"""
|
||||
|
||||
NONE = 0
|
||||
FLUSH = 1
|
||||
COMMIT = 2
|
||||
|
||||
|
||||
"""
|
||||
* managed_tx_method and async_managed_tx_method methods are decorators functions
|
||||
* to be used on class functions. It expects the Class to have a 'db' Session object
|
||||
* initialised
|
||||
* TODO: tx method decorator for non class methods
|
||||
"""
|
||||
|
||||
|
||||
def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT):
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
def wrapped_f(self, *args, **kwargs):
|
||||
try:
|
||||
for i in range(num_retries):
|
||||
try:
|
||||
result = f(self, *args, **kwargs)
|
||||
if auto_commit == CommitMode.COMMIT:
|
||||
self.db.commit()
|
||||
elif auto_commit == CommitMode.FLUSH:
|
||||
self.db.flush()
|
||||
if isinstance(result, SQLModel):
|
||||
self.db.refresh(result)
|
||||
return result
|
||||
except OperationalError:
|
||||
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
|
||||
self.db.rollback()
|
||||
raise OasstError(
|
||||
"DATABASE_MAX_RETIRES_EXHAUSTED",
|
||||
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
|
||||
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("DB Rollback Failure")
|
||||
raise e
|
||||
|
||||
return wrapped_f
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def async_managed_tx_method(
|
||||
auto_commit: CommitMode = CommitMode.COMMIT, num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT
|
||||
):
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
async def wrapped_f(self, *args, **kwargs):
|
||||
try:
|
||||
for i in range(num_retries):
|
||||
try:
|
||||
result = await f(self, *args, **kwargs)
|
||||
if auto_commit == CommitMode.COMMIT:
|
||||
self.db.commit()
|
||||
elif auto_commit == CommitMode.FLUSH:
|
||||
self.db.flush()
|
||||
if isinstance(result, SQLModel):
|
||||
self.db.refresh(result)
|
||||
return result
|
||||
except OperationalError:
|
||||
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
|
||||
self.db.rollback()
|
||||
raise OasstError(
|
||||
"DATABASE_MAX_RETIRES_EXHAUSTED",
|
||||
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
|
||||
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("DB Rollback Failure")
|
||||
raise e
|
||||
|
||||
return wrapped_f
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def default_session_factor() -> Session:
|
||||
return Session(engine)
|
||||
|
||||
|
||||
def managed_tx_function(
|
||||
auto_commit: CommitMode = CommitMode.COMMIT,
|
||||
num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT,
|
||||
session_factory: Callable[..., Session] = default_session_factor,
|
||||
):
|
||||
"""Passes Session object as first argument to wrapped function."""
|
||||
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
def wrapped_f(*args, **kwargs):
|
||||
try:
|
||||
for i in range(num_retries):
|
||||
with session_factory() as session:
|
||||
try:
|
||||
result = f(session, *args, **kwargs)
|
||||
if auto_commit == CommitMode.COMMIT:
|
||||
session.commit()
|
||||
elif auto_commit == CommitMode.FLUSH:
|
||||
session.flush()
|
||||
if isinstance(result, SQLModel):
|
||||
session.refresh(result)
|
||||
return result
|
||||
except OperationalError:
|
||||
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
|
||||
session.rollback()
|
||||
raise OasstError(
|
||||
"DATABASE_MAX_RETIRES_EXHAUSTED",
|
||||
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
|
||||
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("DB Rollback Failure")
|
||||
raise e
|
||||
|
||||
return wrapped_f
|
||||
|
||||
return decorator
|
||||
File diff suppressed because it is too large
Load Diff
@@ -29,8 +29,6 @@ environments:
|
||||
variables:
|
||||
# Note: this has to be a valid JSON list for Pydantic to parse it.
|
||||
BACKEND_CORS_ORIGINS: '["https://web.staging.open-assistant.surfacedata.org"]'
|
||||
DEBUG_ALLOW_DEBUG_API_KEY: True
|
||||
DEBUG_SKIP_API_KEY_CHECK: True
|
||||
MAX_WORKERS: 1
|
||||
|
||||
secrets:
|
||||
|
||||
@@ -16,6 +16,19 @@ http {
|
||||
}
|
||||
}
|
||||
|
||||
server {
|
||||
listen 443 ssl http2;
|
||||
|
||||
server_name www.open-assistant.io;
|
||||
|
||||
ssl_certificate /etc/nginx/ssl/live/www.open-assistant.io/fullchain.pem;
|
||||
ssl_certificate_key /etc/nginx/ssl/live/www.open-assistant.io/privkey.pem;
|
||||
|
||||
location / {
|
||||
return 301 https://open-assistant.io$request_uri;
|
||||
}
|
||||
}
|
||||
|
||||
server {
|
||||
listen 443 ssl http2;
|
||||
|
||||
@@ -25,7 +38,9 @@ http {
|
||||
ssl_certificate_key /etc/nginx/ssl/live/open-assistant.io/privkey.pem;
|
||||
|
||||
location / {
|
||||
return 301 https://web.prod.open-assistant.io$request_uri;
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_pass http://127.0.0.1:3200;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,7 +55,7 @@ http {
|
||||
location / {
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_pass http://127.0.0.1:3000;
|
||||
proxy_pass http://127.0.0.1:3200;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,7 +70,7 @@ http {
|
||||
location / {
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_pass http://127.0.0.1:8080;
|
||||
proxy_pass http://127.0.0.1:8280;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -97,7 +97,6 @@ services:
|
||||
environment:
|
||||
- POSTGRES_HOST=db
|
||||
- REDIS_HOST=redis
|
||||
- DEBUG_SKIP_API_KEY_CHECK=True
|
||||
- DEBUG_USE_SEED_DATA=True
|
||||
- DEBUG_ALLOW_SELF_LABELING=True
|
||||
- MAX_WORKERS=1
|
||||
|
||||
@@ -18,6 +18,7 @@ class OasstErrorCode(IntEnum):
|
||||
DATABASE_URI_NOT_SET = 1
|
||||
API_CLIENT_NOT_AUTHORIZED = 2
|
||||
ROOT_TOKEN_NOT_AUTHORIZED = 3
|
||||
DATABASE_MAX_RETRIES_EXHAUSTED = 4
|
||||
TOO_MANY_REQUESTS = 429
|
||||
|
||||
SERVER_ERROR0 = 500
|
||||
@@ -31,6 +32,7 @@ class OasstErrorCode(IntEnum):
|
||||
TASK_INTERACTION_REQUEST_FAILED = 1004
|
||||
TASK_GENERATION_FAILED = 1005
|
||||
TASK_REQUESTED_TYPE_NOT_AVAILABLE = 1006
|
||||
TASK_AVAILABILITY_QUERY_FAILED = 1007
|
||||
|
||||
# 2000-3000: prompt_repository
|
||||
INVALID_FRONTEND_MESSAGE_ID = 2000
|
||||
@@ -59,6 +61,7 @@ class OasstErrorCode(IntEnum):
|
||||
TASK_ALREADY_DONE = 2105
|
||||
TASK_NOT_COLLECTIVE = 2106
|
||||
TASK_NOT_ASSIGNED_TO_USER = 2106
|
||||
TASK_UNEXPECTED_PAYLOAD_TYPE_ = 2107
|
||||
USER_NOT_FOUND = 2200
|
||||
|
||||
# 3000-4000: external resources
|
||||
|
||||
@@ -4,7 +4,6 @@ parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
|
||||
# switch to backend directory
|
||||
pushd "$parent_path/../../backend"
|
||||
|
||||
export DEBUG_SKIP_API_KEY_CHECK=False
|
||||
export DEBUG_USE_SEED_DATA=True
|
||||
export DEBUG_SKIP_TOXICITY_CALCULATION=True
|
||||
export DEBUG_ALLOW_SELF_LABELING=True
|
||||
|
||||
@@ -25,7 +25,7 @@ def _render_message(message: dict) -> str:
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY"):
|
||||
def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
|
||||
"""Simple REPL frontend."""
|
||||
|
||||
def _post(path: str, json: dict) -> dict:
|
||||
|
||||
Generated
+45
-95
@@ -29,7 +29,6 @@
|
||||
"eslint-config-next": "13.0.6",
|
||||
"eslint-plugin-simple-import-sort": "^8.0.0",
|
||||
"focus-visible": "^5.2.0",
|
||||
"formik": "^2.2.9",
|
||||
"framer-motion": "^6.5.1",
|
||||
"install": "^0.13.0",
|
||||
"next": "13.0.6",
|
||||
@@ -40,7 +39,9 @@
|
||||
"react": "18.2.0",
|
||||
"react-dom": "18.2.0",
|
||||
"react-feature-flags": "^1.0.0",
|
||||
"react-hook-form": "^7.42.1",
|
||||
"react-icons": "^4.7.1",
|
||||
"react-table": "^7.8.0",
|
||||
"sharp": "^0.31.3",
|
||||
"swr": "^2.0.0",
|
||||
"tailwindcss": "^3.2.4",
|
||||
@@ -20303,47 +20304,6 @@
|
||||
"node": ">= 6"
|
||||
}
|
||||
},
|
||||
"node_modules/formik": {
|
||||
"version": "2.2.9",
|
||||
"resolved": "https://registry.npmjs.org/formik/-/formik-2.2.9.tgz",
|
||||
"integrity": "sha512-LQLcISMmf1r5at4/gyJigGn0gOwFbeEAlji+N9InZF6LIMXnFNkO42sCI8Jt84YZggpD4cPWObAZaxpEFtSzNA==",
|
||||
"funding": [
|
||||
{
|
||||
"type": "individual",
|
||||
"url": "https://opencollective.com/formik"
|
||||
}
|
||||
],
|
||||
"dependencies": {
|
||||
"deepmerge": "^2.1.1",
|
||||
"hoist-non-react-statics": "^3.3.0",
|
||||
"lodash": "^4.17.21",
|
||||
"lodash-es": "^4.17.21",
|
||||
"react-fast-compare": "^2.0.1",
|
||||
"tiny-warning": "^1.0.2",
|
||||
"tslib": "^1.10.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": ">=16.8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/formik/node_modules/deepmerge": {
|
||||
"version": "2.2.1",
|
||||
"resolved": "https://registry.npmjs.org/deepmerge/-/deepmerge-2.2.1.tgz",
|
||||
"integrity": "sha512-R9hc1Xa/NOBi9WRVUWg19rl1UB7Tt4kuPd+thNJgFZoxXsTz7ncaPaeIm+40oSGuP33DfMb4sZt1QIGiJzC4EA==",
|
||||
"engines": {
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/formik/node_modules/react-fast-compare": {
|
||||
"version": "2.0.4",
|
||||
"resolved": "https://registry.npmjs.org/react-fast-compare/-/react-fast-compare-2.0.4.tgz",
|
||||
"integrity": "sha512-suNP+J1VU1MWFKcyt7RtjiSWUjvidmQSlqu+eHslq+342xCbGTYmC0mEhPCOHxlW0CywylOC1u2DFAT+bv4dBw=="
|
||||
},
|
||||
"node_modules/formik/node_modules/tslib": {
|
||||
"version": "1.14.1",
|
||||
"resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz",
|
||||
"integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg=="
|
||||
},
|
||||
"node_modules/forwarded": {
|
||||
"version": "0.2.0",
|
||||
"resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz",
|
||||
@@ -26442,12 +26402,8 @@
|
||||
"node_modules/lodash": {
|
||||
"version": "4.17.21",
|
||||
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz",
|
||||
"integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg=="
|
||||
},
|
||||
"node_modules/lodash-es": {
|
||||
"version": "4.17.21",
|
||||
"resolved": "https://registry.npmjs.org/lodash-es/-/lodash-es-4.17.21.tgz",
|
||||
"integrity": "sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw=="
|
||||
"integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/lodash.debounce": {
|
||||
"version": "4.0.8",
|
||||
@@ -32526,6 +32482,21 @@
|
||||
"node": ">=10"
|
||||
}
|
||||
},
|
||||
"node_modules/react-hook-form": {
|
||||
"version": "7.42.1",
|
||||
"resolved": "https://registry.npmjs.org/react-hook-form/-/react-hook-form-7.42.1.tgz",
|
||||
"integrity": "sha512-2UIGqwMZksd5HS55crTT1ATLTr0rAI4jS7yVuqTaoRVDhY2Qc4IyjskCmpnmdYqUNOYFy04vW253tb2JRVh+IQ==",
|
||||
"engines": {
|
||||
"node": ">=12.22.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/react-hook-form"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": "^16.8.0 || ^17 || ^18"
|
||||
}
|
||||
},
|
||||
"node_modules/react-icons": {
|
||||
"version": "4.7.1",
|
||||
"resolved": "https://registry.npmjs.org/react-icons/-/react-icons-4.7.1.tgz",
|
||||
@@ -32625,6 +32596,18 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/react-table": {
|
||||
"version": "7.8.0",
|
||||
"resolved": "https://registry.npmjs.org/react-table/-/react-table-7.8.0.tgz",
|
||||
"integrity": "sha512-hNaz4ygkZO4bESeFfnfOft73iBUj8K5oKi1EcSHPAibEydfsX2MyU6Z8KCr3mv3C9Kqqh71U+DhZkFvibbnPbA==",
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/tannerlinsley"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": "^16.8.3 || ^17.0.0-0 || ^18.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/read-cache": {
|
||||
"version": "1.0.0",
|
||||
"resolved": "https://registry.npmjs.org/read-cache/-/read-cache-1.0.0.tgz",
|
||||
@@ -35473,11 +35456,6 @@
|
||||
"resolved": "https://registry.npmjs.org/tiny-invariant/-/tiny-invariant-1.3.1.tgz",
|
||||
"integrity": "sha512-AD5ih2NlSssTCwsMznbvwMZpJ1cbhkGd2uueNxzv2jDlEeZdU04JQfRnggJQ8DrcVBGjAsCKwFBbDlVNtEMlzw=="
|
||||
},
|
||||
"node_modules/tiny-warning": {
|
||||
"version": "1.0.3",
|
||||
"resolved": "https://registry.npmjs.org/tiny-warning/-/tiny-warning-1.0.3.tgz",
|
||||
"integrity": "sha512-lBN9zLN/oAf68o3zNXYrdCt1kP8WsiGW8Oo2ka41b2IM5JL/S1CTyX1rW0mb/zSuJun0ZUrDxx4sqvYS2FWzPA=="
|
||||
},
|
||||
"node_modules/tmp": {
|
||||
"version": "0.2.1",
|
||||
"resolved": "https://registry.npmjs.org/tmp/-/tmp-0.2.1.tgz",
|
||||
@@ -52916,37 +52894,6 @@
|
||||
"mime-types": "^2.1.12"
|
||||
}
|
||||
},
|
||||
"formik": {
|
||||
"version": "2.2.9",
|
||||
"resolved": "https://registry.npmjs.org/formik/-/formik-2.2.9.tgz",
|
||||
"integrity": "sha512-LQLcISMmf1r5at4/gyJigGn0gOwFbeEAlji+N9InZF6LIMXnFNkO42sCI8Jt84YZggpD4cPWObAZaxpEFtSzNA==",
|
||||
"requires": {
|
||||
"deepmerge": "^2.1.1",
|
||||
"hoist-non-react-statics": "^3.3.0",
|
||||
"lodash": "^4.17.21",
|
||||
"lodash-es": "^4.17.21",
|
||||
"react-fast-compare": "^2.0.1",
|
||||
"tiny-warning": "^1.0.2",
|
||||
"tslib": "^1.10.0"
|
||||
},
|
||||
"dependencies": {
|
||||
"deepmerge": {
|
||||
"version": "2.2.1",
|
||||
"resolved": "https://registry.npmjs.org/deepmerge/-/deepmerge-2.2.1.tgz",
|
||||
"integrity": "sha512-R9hc1Xa/NOBi9WRVUWg19rl1UB7Tt4kuPd+thNJgFZoxXsTz7ncaPaeIm+40oSGuP33DfMb4sZt1QIGiJzC4EA=="
|
||||
},
|
||||
"react-fast-compare": {
|
||||
"version": "2.0.4",
|
||||
"resolved": "https://registry.npmjs.org/react-fast-compare/-/react-fast-compare-2.0.4.tgz",
|
||||
"integrity": "sha512-suNP+J1VU1MWFKcyt7RtjiSWUjvidmQSlqu+eHslq+342xCbGTYmC0mEhPCOHxlW0CywylOC1u2DFAT+bv4dBw=="
|
||||
},
|
||||
"tslib": {
|
||||
"version": "1.14.1",
|
||||
"resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz",
|
||||
"integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg=="
|
||||
}
|
||||
}
|
||||
},
|
||||
"forwarded": {
|
||||
"version": "0.2.0",
|
||||
"resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz",
|
||||
@@ -57561,12 +57508,8 @@
|
||||
"lodash": {
|
||||
"version": "4.17.21",
|
||||
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz",
|
||||
"integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg=="
|
||||
},
|
||||
"lodash-es": {
|
||||
"version": "4.17.21",
|
||||
"resolved": "https://registry.npmjs.org/lodash-es/-/lodash-es-4.17.21.tgz",
|
||||
"integrity": "sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw=="
|
||||
"integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==",
|
||||
"dev": true
|
||||
},
|
||||
"lodash.debounce": {
|
||||
"version": "4.0.8",
|
||||
@@ -61939,6 +61882,12 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"react-hook-form": {
|
||||
"version": "7.42.1",
|
||||
"resolved": "https://registry.npmjs.org/react-hook-form/-/react-hook-form-7.42.1.tgz",
|
||||
"integrity": "sha512-2UIGqwMZksd5HS55crTT1ATLTr0rAI4jS7yVuqTaoRVDhY2Qc4IyjskCmpnmdYqUNOYFy04vW253tb2JRVh+IQ==",
|
||||
"requires": {}
|
||||
},
|
||||
"react-icons": {
|
||||
"version": "4.7.1",
|
||||
"resolved": "https://registry.npmjs.org/react-icons/-/react-icons-4.7.1.tgz",
|
||||
@@ -61993,6 +61942,12 @@
|
||||
"tslib": "^2.0.0"
|
||||
}
|
||||
},
|
||||
"react-table": {
|
||||
"version": "7.8.0",
|
||||
"resolved": "https://registry.npmjs.org/react-table/-/react-table-7.8.0.tgz",
|
||||
"integrity": "sha512-hNaz4ygkZO4bESeFfnfOft73iBUj8K5oKi1EcSHPAibEydfsX2MyU6Z8KCr3mv3C9Kqqh71U+DhZkFvibbnPbA==",
|
||||
"requires": {}
|
||||
},
|
||||
"read-cache": {
|
||||
"version": "1.0.0",
|
||||
"resolved": "https://registry.npmjs.org/read-cache/-/read-cache-1.0.0.tgz",
|
||||
@@ -64246,11 +64201,6 @@
|
||||
"resolved": "https://registry.npmjs.org/tiny-invariant/-/tiny-invariant-1.3.1.tgz",
|
||||
"integrity": "sha512-AD5ih2NlSssTCwsMznbvwMZpJ1cbhkGd2uueNxzv2jDlEeZdU04JQfRnggJQ8DrcVBGjAsCKwFBbDlVNtEMlzw=="
|
||||
},
|
||||
"tiny-warning": {
|
||||
"version": "1.0.3",
|
||||
"resolved": "https://registry.npmjs.org/tiny-warning/-/tiny-warning-1.0.3.tgz",
|
||||
"integrity": "sha512-lBN9zLN/oAf68o3zNXYrdCt1kP8WsiGW8Oo2ka41b2IM5JL/S1CTyX1rW0mb/zSuJun0ZUrDxx4sqvYS2FWzPA=="
|
||||
},
|
||||
"tmp": {
|
||||
"version": "0.2.1",
|
||||
"resolved": "https://registry.npmjs.org/tmp/-/tmp-0.2.1.tgz",
|
||||
|
||||
@@ -46,7 +46,6 @@
|
||||
"eslint-config-next": "13.0.6",
|
||||
"eslint-plugin-simple-import-sort": "^8.0.0",
|
||||
"focus-visible": "^5.2.0",
|
||||
"formik": "^2.2.9",
|
||||
"framer-motion": "^6.5.1",
|
||||
"install": "^0.13.0",
|
||||
"next": "13.0.6",
|
||||
@@ -57,7 +56,9 @@
|
||||
"react": "18.2.0",
|
||||
"react-dom": "18.2.0",
|
||||
"react-feature-flags": "^1.0.0",
|
||||
"react-hook-form": "^7.42.1",
|
||||
"react-icons": "^4.7.1",
|
||||
"react-table": "^7.8.0",
|
||||
"sharp": "^0.31.3",
|
||||
"swr": "^2.0.0",
|
||||
"tailwindcss": "^3.2.4",
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
import { Box, Link, Stack, StackDivider, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import { Box, Link, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import NextLink from "next/link";
|
||||
import { get } from "src/lib/api";
|
||||
import useSWR from "swr";
|
||||
import { LeaderboardGridCell } from "src/components/LeaderboardGridCell";
|
||||
import { LeaderboardTimeFrame } from "src/types/Leaderboard";
|
||||
|
||||
export function LeaderboardTable() {
|
||||
const backgroundColor = useColorModeValue("white", "gray.700");
|
||||
const accentColor = useColorModeValue("gray.200", "gray.900");
|
||||
const { data: leaderboardEntries } = useSWR("/api/leaderboard", get);
|
||||
return (
|
||||
<main className="h-fit col-span-3">
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex items-end justify-between">
|
||||
<Text className="text-2xl font-bold">Top 5 Contributors</Text>
|
||||
<Text className="text-2xl font-bold">Top 5 Contributors Today</Text>
|
||||
<Link as={NextLink} href="/leaderboard" _hover={{ textDecoration: "none" }}>
|
||||
<Text color="blue.400" className="text-sm font-bold">
|
||||
View All ->
|
||||
@@ -25,30 +24,7 @@ export function LeaderboardTable() {
|
||||
borderRadius="xl"
|
||||
className="p-6 shadow-sm"
|
||||
>
|
||||
<Stack divider={<StackDivider />} spacing="4">
|
||||
<div className="grid grid-cols-4 items-center font-bold">
|
||||
<p>Name</p>
|
||||
<div className="col-start-4 flex justify-center">
|
||||
<p>Score</p>
|
||||
</div>
|
||||
</div>
|
||||
{leaderboardEntries?.map(({ display_name, score }, idx) => (
|
||||
<div key={idx} className="grid grid-cols-4 items-center">
|
||||
<div className="flex items-center gap-3">
|
||||
{/*
|
||||
<Image alt="Profile Picture" src={item.image} boxSize="7" borderRadius="full"></Image>
|
||||
*/}
|
||||
<p>{display_name}</p>
|
||||
{/*
|
||||
<Badge colorScheme="purple">{item.streakCount}</Badge>
|
||||
*/}
|
||||
</div>
|
||||
<Box bg={backgroundColor} className="col-start-4 flex justify-center">
|
||||
<p>{score}</p>
|
||||
</Box>
|
||||
</div>
|
||||
))}
|
||||
</Stack>
|
||||
<LeaderboardGridCell timeFrame={LeaderboardTimeFrame.day} />
|
||||
</Box>
|
||||
</div>
|
||||
</main>
|
||||
|
||||
@@ -1,97 +1,73 @@
|
||||
import { Avatar, Box, Grid, GridItem, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import { FiChevronDown } from "react-icons/fi";
|
||||
import { Table, TableContainer, Tbody, Td, Th, Thead, Tr, useColorModeValue } from "@chakra-ui/react";
|
||||
import React from "react";
|
||||
import { useTable } from "react-table";
|
||||
import { get } from "src/lib/api";
|
||||
import useSWR from "swr";
|
||||
import { LeaderboardEntity, LeaderboardTimeFrame } from "src/types/Leaderboard";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
|
||||
const columns = [
|
||||
{
|
||||
Header: "Rank",
|
||||
accessor: (item: LeaderboardEntity, rowIndex: number) => "#" + item.rank,
|
||||
style: { width: "90px" },
|
||||
},
|
||||
{
|
||||
Header: "Score",
|
||||
accessor: "leader_score",
|
||||
style: { width: "90px" },
|
||||
},
|
||||
{
|
||||
Header: "User",
|
||||
accessor: "display_name",
|
||||
},
|
||||
];
|
||||
|
||||
/**
|
||||
* Presents a grid of leaderboard entries with more detailed information.
|
||||
*/
|
||||
const LeaderboardGridCell = () => {
|
||||
const { data: leaderboardEntries } = useSWR("/api/leaderboard", get);
|
||||
const LeaderboardGridCell = ({ timeFrame }: { timeFrame: LeaderboardTimeFrame }) => {
|
||||
const { data } = useSWRImmutable<LeaderboardEntity[]>(`/api/leaderboard?time_frame=${timeFrame}`, get, {
|
||||
fallbackData: [],
|
||||
revalidateOnMount: true,
|
||||
});
|
||||
const backgroundColor = useColorModeValue("white", "gray.800");
|
||||
const columns = `repeat(${FILTER.length}, 1fr)`;
|
||||
|
||||
const { getTableProps, getTableBodyProps, headerGroups, rows, prepareRow } = useTable({ columns, data });
|
||||
|
||||
return (
|
||||
<>
|
||||
<Grid>
|
||||
<GridItem
|
||||
colSpan={4}
|
||||
bg={backgroundColor}
|
||||
display="grid"
|
||||
gridTemplateColumns={columns}
|
||||
p="4"
|
||||
borderRadius="lg"
|
||||
mb="4"
|
||||
shadow="base"
|
||||
>
|
||||
{FILTER.map(({ title, GridItemProps }, index) => (
|
||||
<GridItem key={index} display="flex" {...GridItemProps}>
|
||||
<Box display="flex" alignItems="center" gap="2" width="fit-content" borderRadius="md" cursor="pointer">
|
||||
<Text fontSize="sm" fontWeight="bold" textTransform="uppercase">
|
||||
{title}
|
||||
</Text>
|
||||
|
||||
<FiChevronDown size="16" />
|
||||
</Box>
|
||||
</GridItem>
|
||||
<TableContainer>
|
||||
<Table {...getTableProps()}>
|
||||
<Thead bg={backgroundColor}>
|
||||
{headerGroups.map((headerGroup, idx) => (
|
||||
<Tr key={idx} {...headerGroup.getHeaderGroupProps()}>
|
||||
{headerGroup.headers.map((column) => (
|
||||
<Th {...column.getHeaderProps([{ style: column.style }])} key={column.id}>
|
||||
{column.render("Header")}
|
||||
</Th>
|
||||
))}
|
||||
</Tr>
|
||||
))}
|
||||
</GridItem>
|
||||
</Grid>
|
||||
<Grid templateColumns={columns} bg={backgroundColor} borderRadius="xl" shadow="base" p="4" gap="6">
|
||||
{leaderboardEntries?.map(({ display_name, ranking, score }, index) => (
|
||||
<GridItem key={index} colSpan={4} display="grid" gridTemplateColumns={columns} borderRadius="lg" p="2">
|
||||
<GridItem overflow="hidden">
|
||||
<Box display="flex" gap="2">
|
||||
<Avatar size="xs" />
|
||||
<Text>{display_name}</Text>
|
||||
</Box>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<GridItem display="flex" justifyContent="center">
|
||||
<Text>{ranking}</Text>
|
||||
</GridItem>
|
||||
</GridItem>
|
||||
<GridItem display="flex" justifyContent="center">
|
||||
<Text>{score}</Text>
|
||||
</GridItem>
|
||||
{/*
|
||||
<GridItem display="flex" justifyContent="center">
|
||||
<Text fontSize="xl">{item.medal}</Text>
|
||||
</GridItem>
|
||||
*/}
|
||||
</GridItem>
|
||||
))}
|
||||
</Grid>
|
||||
</>
|
||||
</Thead>
|
||||
|
||||
<Tbody {...getTableBodyProps()}>
|
||||
{rows.map((row) => {
|
||||
prepareRow(row);
|
||||
return (
|
||||
<Tr key={row.id} {...row.getRowProps()}>
|
||||
{row.cells.map((cell, idx) => {
|
||||
return (
|
||||
<Td key={row.id + idx} {...cell.getCellProps([{ style: cell.column.style }])}>
|
||||
{cell.render("Cell")}
|
||||
</Td>
|
||||
);
|
||||
})}
|
||||
</Tr>
|
||||
);
|
||||
})}
|
||||
</Tbody>
|
||||
</Table>
|
||||
</TableContainer>
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Specifies the table headers in the grid.
|
||||
*/
|
||||
const FILTER = [
|
||||
{
|
||||
title: "User",
|
||||
isActive: false,
|
||||
GridItemProps: { justifyContent: "start" },
|
||||
},
|
||||
{
|
||||
title: "Rank",
|
||||
isActive: false,
|
||||
GridItemProps: { justifyContent: "center" },
|
||||
},
|
||||
{
|
||||
title: "Score",
|
||||
isActive: false,
|
||||
GridItemProps: { justifyContent: "center" },
|
||||
},
|
||||
/*
|
||||
{
|
||||
title: "Medal",
|
||||
isActive: false,
|
||||
GridItemProps: { justifyContent: "center" },
|
||||
},
|
||||
*/
|
||||
];
|
||||
|
||||
export { LeaderboardGridCell };
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import { Box, CircularProgress, Stack, StackProps, Text, TextProps, useColorModeValue } from "@chakra-ui/react";
|
||||
import { boolean } from "boolean";
|
||||
import { useState } from "react";
|
||||
import { MessageTableEntry } from "src/components/Messages/MessageTableEntry";
|
||||
import { get } from "src/lib/api";
|
||||
import useSWR from "swr";
|
||||
import { Message } from "src/types/Conversation";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
|
||||
const MessageHeaderProps: TextProps = {
|
||||
fontSize: "xl",
|
||||
@@ -21,39 +20,24 @@ const MessageStackProps: StackProps = {
|
||||
interface MessageWithChildrenProps {
|
||||
id: string;
|
||||
depth?: number;
|
||||
maxDepth?: number;
|
||||
maxDepth: number;
|
||||
isOnlyChild?: boolean;
|
||||
}
|
||||
|
||||
export function MessageWithChildren(props: MessageWithChildrenProps) {
|
||||
const backgroundColor = useColorModeValue("white", "gray.800");
|
||||
const childBackgroundColor = useColorModeValue("gray.200", "gray.700");
|
||||
const { id, depth = 0, maxDepth, isOnlyChild = true } = props;
|
||||
|
||||
const { id, depth, maxDepth, isOnlyChild = true } = props;
|
||||
const { isLoading, data: message } = useSWRImmutable<Message>(`/api/messages/${id}`, get);
|
||||
const { isLoading: isLoadingChildren, data: children = [] } = useSWRImmutable<Message[]>(
|
||||
`/api/messages/${id}/children`,
|
||||
get
|
||||
);
|
||||
|
||||
const [message, setMessage] = useState(null);
|
||||
const [children, setChildren] = useState(null);
|
||||
|
||||
const { isLoading } = useSWR(id ? `/api/messages/${id}` : null, get, {
|
||||
onSuccess: (data) => {
|
||||
setMessage(data);
|
||||
},
|
||||
onError: () => {
|
||||
setMessage(null);
|
||||
},
|
||||
});
|
||||
const { isLoading: isLoadingChildren } = useSWR(id ? `/api/messages/${id}/children` : null, get, {
|
||||
onSuccess: (data) => {
|
||||
setChildren(data);
|
||||
},
|
||||
onError: () => {
|
||||
setChildren(null);
|
||||
},
|
||||
});
|
||||
|
||||
const renderRecursive = maxDepth && ((depth && depth < maxDepth) || !depth);
|
||||
const isFirst = depth === 0 || !depth;
|
||||
const isFirstOrOnly = isFirst || boolean(isOnlyChild);
|
||||
const renderRecursive = depth < maxDepth || depth === 0;
|
||||
const isFirst = depth === 0;
|
||||
const isFirstOrOnly = isFirst || !!isOnlyChild;
|
||||
|
||||
if (isLoading || isLoadingChildren) {
|
||||
return <CircularProgress isIndeterminate />;
|
||||
@@ -73,15 +57,15 @@ export function MessageWithChildren(props: MessageWithChildrenProps) {
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
{children && Array.isArray(children) && children.length > 0 ? (
|
||||
renderRecursive ? (
|
||||
{children.length > 0 &&
|
||||
(renderRecursive ? (
|
||||
<Stack {...MessageStackProps}>
|
||||
<Box bg={childBackgroundColor} padding="4" borderRadius="xl">
|
||||
{children.map((item, idx) => (
|
||||
<Box flex="1" key={`recursiveMessageWChildren_${idx}`}>
|
||||
{children.map((item) => (
|
||||
<Box flex="1" key={`recursiveMessageWChildren_${item.id}`}>
|
||||
<MessageWithChildren
|
||||
id={item.id}
|
||||
depth={depth ? depth + 1 : 1}
|
||||
depth={depth + 1}
|
||||
maxDepth={maxDepth}
|
||||
isOnlyChild={children.length === 1 && isOnlyChild}
|
||||
/>
|
||||
@@ -110,10 +94,7 @@ export function MessageWithChildren(props: MessageWithChildrenProps) {
|
||||
</Box>
|
||||
</Stack>
|
||||
</>
|
||||
)
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
))}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
import { Select, SelectProps } from "@chakra-ui/react";
|
||||
import { forwardRef } from "react";
|
||||
import { ElementOf } from "src/types/utils";
|
||||
|
||||
export const roles = ["general", "admin", "banned"] as const;
|
||||
export type Role = ElementOf<typeof roles>;
|
||||
|
||||
type RoleSelectProps = Omit<SelectProps, "defaultValue"> & {
|
||||
defaultValue?: Role;
|
||||
value?: Role;
|
||||
};
|
||||
|
||||
export const RoleSelect = forwardRef<HTMLSelectElement, RoleSelectProps>((props, ref) => {
|
||||
return (
|
||||
<Select {...props} ref={ref}>
|
||||
{roles.map((role) => (
|
||||
<option value={role} key={role}>
|
||||
{role}
|
||||
</option>
|
||||
))}
|
||||
</Select>
|
||||
);
|
||||
});
|
||||
|
||||
RoleSelect.displayName = "RoleSelect";
|
||||
@@ -1,4 +1,18 @@
|
||||
import { Box, Button, Flex, useColorMode } from "@chakra-ui/react";
|
||||
import {
|
||||
Box,
|
||||
Button,
|
||||
Flex,
|
||||
IconButton,
|
||||
Popover,
|
||||
PopoverArrow,
|
||||
PopoverBody,
|
||||
PopoverCloseButton,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
Text,
|
||||
useColorMode,
|
||||
} from "@chakra-ui/react";
|
||||
import { InformationCircleIcon } from "@heroicons/react/20/solid";
|
||||
import { useId, useState } from "react";
|
||||
import { colors } from "src/styles/Theme/colors";
|
||||
|
||||
@@ -8,6 +22,17 @@ interface LabelRadioGroupProps {
|
||||
isEditable?: boolean;
|
||||
}
|
||||
|
||||
const label_messages: { [label: string]: { description: string; explanation: string[] } } = {
|
||||
spam: {
|
||||
description: "The message is spam?",
|
||||
explanation: [
|
||||
'We consider the following unwanted content as spam: trolling, intentional undermining of our purpose, illegal material, material that violates our code of conduct, and other things that are inappropriate for our dataset. We collect these under the common heading of "spam".',
|
||||
"This is not an assessment of whether this message is the best possible answer. Especially for prompts or user-replies, we very much want to retain all kinds of responses in the dataset, so that the assistant can learn to reply appropriately.",
|
||||
"Please mark this text as spam only if it is clearly unsuited to be part of our dataset, as outlined above, and try not to make any subjective value-judgements beyond that.",
|
||||
],
|
||||
},
|
||||
};
|
||||
|
||||
export const LabelRadioGroup = (props: LabelRadioGroupProps) => {
|
||||
const [labelValues, setLabelValues] = useState<number[]>(Array.from({ length: props.labelIDs.length }).map(() => 0));
|
||||
const [interactionFlag, setInteractionFlag] = useState(false);
|
||||
@@ -17,7 +42,7 @@ export const LabelRadioGroup = (props: LabelRadioGroupProps) => {
|
||||
{props.labelIDs.map((labelId, idx) => (
|
||||
<LabelRadioItem
|
||||
key={idx}
|
||||
labelId={labelId}
|
||||
labelText={label_messages[labelId] || { description: labelId }}
|
||||
labelValue={labelValues[idx]}
|
||||
clickHandler={(newValue) => {
|
||||
const newState = labelValues.slice();
|
||||
@@ -45,7 +70,7 @@ interface ButtonState {
|
||||
}
|
||||
|
||||
interface LabelRadioItemProps {
|
||||
labelId: string;
|
||||
labelText: { description: string; explanation?: string[] };
|
||||
labelValue: number;
|
||||
clickHandler: (newVal: number) => unknown;
|
||||
states: ButtonState[];
|
||||
@@ -63,7 +88,27 @@ const LabelRadioItem = (props: LabelRadioItemProps) => {
|
||||
<Box data-cy="label-group-item" data-label-type="radio">
|
||||
<label className="text-sm" htmlFor={id}>
|
||||
{/* TODO: display real text instead of just the id */}
|
||||
<span className={labelTextClass}>{props.labelId}</span>
|
||||
<span className={labelTextClass}>{props.labelText.description}</span>
|
||||
{props.labelText.explanation ? (
|
||||
<Popover>
|
||||
<PopoverTrigger>
|
||||
<IconButton
|
||||
aria-label="explanation"
|
||||
variant="link"
|
||||
icon={<InformationCircleIcon className="h-5 w-5" />}
|
||||
></IconButton>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent>
|
||||
<PopoverArrow />
|
||||
<PopoverCloseButton />
|
||||
<PopoverBody>
|
||||
{props.labelText.explanation.map((paragraph, idx) => (
|
||||
<Text key={idx}>{paragraph}</Text>
|
||||
))}
|
||||
</PopoverBody>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
) : null}
|
||||
</label>
|
||||
<Flex direction="row" gap={6} justify="center">
|
||||
{props.states.map((item, idx) => (
|
||||
|
||||
@@ -66,7 +66,7 @@ export const LabelTask = ({
|
||||
</Box>
|
||||
)}
|
||||
</>
|
||||
{valid_labels.length === 1 ? (
|
||||
{task.mode === "simple" ? (
|
||||
<LabelRadioGroup labelIDs={task.valid_labels} isEditable={isEditable} onChange={onSliderChange} />
|
||||
) : (
|
||||
<LabelSliderGroup labelIDs={task.valid_labels} isEditable={isEditable} onChange={onSliderChange} />
|
||||
|
||||
@@ -27,7 +27,7 @@ export const Task = ({ frontendId, task, trigger, mutate }) => {
|
||||
const replyContent = useRef<TaskContent>(null);
|
||||
const [showUnchangedWarning, setShowUnchangedWarning] = useState(false);
|
||||
|
||||
const taskType = TaskTypes.find((taskType) => taskType.type === task.type);
|
||||
const taskType = TaskTypes.find((taskType) => taskType.type === task.type && taskType.mode === task.mode);
|
||||
|
||||
const { trigger: sendRejection } = useSWRMutation("/api/reject_task", post, {
|
||||
onSuccess: async () => {
|
||||
|
||||
@@ -11,6 +11,7 @@ export interface TaskInfo {
|
||||
category: TaskCategory;
|
||||
pathname: string;
|
||||
type: string;
|
||||
mode?: string;
|
||||
overview?: string;
|
||||
instruction?: string;
|
||||
update_type: string;
|
||||
@@ -90,7 +91,7 @@ export const TaskTypes: TaskInfo[] = [
|
||||
unchanged_title: "Order Unchanged",
|
||||
unchanged_message: "You have not changed the order of the prompts. Are you sure you would like to continue?",
|
||||
},
|
||||
// label
|
||||
// label (fuill)
|
||||
{
|
||||
label: "Label Initial Prompt",
|
||||
desc: "Provide labels for a prompt.",
|
||||
@@ -98,6 +99,7 @@ export const TaskTypes: TaskInfo[] = [
|
||||
pathname: "/label/label_initial_prompt",
|
||||
overview: "Provide labels for the following prompt",
|
||||
type: "label_initial_prompt",
|
||||
mode: "full",
|
||||
update_type: "text_labels",
|
||||
},
|
||||
{
|
||||
@@ -107,6 +109,7 @@ export const TaskTypes: TaskInfo[] = [
|
||||
pathname: "/label/label_prompter_reply",
|
||||
overview: "Given the following discussion, provide labels for the final promp",
|
||||
type: "label_prompter_reply",
|
||||
mode: "full",
|
||||
update_type: "text_labels",
|
||||
},
|
||||
{
|
||||
@@ -116,6 +119,38 @@ export const TaskTypes: TaskInfo[] = [
|
||||
pathname: "/label/label_assistant_reply",
|
||||
overview: "Given the following discussion, provide labels for the final prompt.",
|
||||
type: "label_assistant_reply",
|
||||
mode: "full",
|
||||
update_type: "text_labels",
|
||||
},
|
||||
// label (simple)
|
||||
{
|
||||
label: "Classify Initial Prompt",
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_initial_prompt",
|
||||
overview: "Read the following prompt and then answer the question about it.",
|
||||
type: "label_initial_prompt",
|
||||
mode: "simple",
|
||||
update_type: "text_labels",
|
||||
},
|
||||
{
|
||||
label: "Classify Prompter Reply",
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_prompter_reply",
|
||||
overview: "Read the following conversation and then answer the question about the last prompt in the disscusion.",
|
||||
type: "label_prompter_reply",
|
||||
mode: "simple",
|
||||
update_type: "text_labels",
|
||||
},
|
||||
{
|
||||
label: "Classify Assistant Reply",
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_assistant_reply",
|
||||
overview: "Read the following conversation and then answer the question about the last prompt in the disscusion.",
|
||||
type: "label_assistant_reply",
|
||||
mode: "simple",
|
||||
update_type: "text_labels",
|
||||
},
|
||||
];
|
||||
|
||||
@@ -11,6 +11,7 @@ import {
|
||||
Th,
|
||||
Thead,
|
||||
Tr,
|
||||
useToast,
|
||||
} from "@chakra-ui/react";
|
||||
import Link from "next/link";
|
||||
import { useState } from "react";
|
||||
@@ -18,26 +19,78 @@ import { get } from "src/lib/api";
|
||||
import type { User } from "src/types/Users";
|
||||
import useSWR from "swr";
|
||||
|
||||
interface Pagination {
|
||||
/**
|
||||
* The user's `display_name` used for pagination.
|
||||
*/
|
||||
cursor: string;
|
||||
|
||||
/**
|
||||
* The pagination direction.
|
||||
*/
|
||||
direction: "forward" | "back";
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches users from the users api route and then presents them in a simple Chakra table.
|
||||
*/
|
||||
const UsersCell = () => {
|
||||
const [pageIndex, setPageIndex] = useState(0);
|
||||
const toast = useToast();
|
||||
const [pagination, setPagination] = useState<Pagination>({ cursor: "", direction: "forward" });
|
||||
const [users, setUsers] = useState<User[]>([]);
|
||||
|
||||
// Fetch and save the users.
|
||||
// This follows useSWR's recommendation for simple pagination:
|
||||
// https://swr.vercel.app/docs/pagination#when-to-use-useswr
|
||||
useSWR(`/api/admin/users?pageIndex=${pageIndex}`, get, {
|
||||
onSuccess: setUsers,
|
||||
useSWR(`/api/admin/users?direction=${pagination.direction}&cursor=${pagination.cursor}`, get, {
|
||||
onSuccess: (data) => {
|
||||
// When no more users can be found, trigger a toast to indicate why no
|
||||
// changes have taken place. We have to maintain a non-empty set of
|
||||
// users otherwise we can't paginate using a cursor (since we've lost the
|
||||
// cursor).
|
||||
if (data.length === 0) {
|
||||
toast({
|
||||
title: "No more users",
|
||||
status: "warning",
|
||||
duration: 1000,
|
||||
isClosable: true,
|
||||
});
|
||||
return;
|
||||
}
|
||||
setUsers(data);
|
||||
},
|
||||
});
|
||||
|
||||
const toPreviousPage = () => {
|
||||
setPageIndex(Math.max(0, pageIndex - 1));
|
||||
if (users.length >= 0) {
|
||||
setPagination({
|
||||
cursor: users[0].display_name,
|
||||
direction: "back",
|
||||
});
|
||||
} else {
|
||||
toast({
|
||||
title: "Can not paginate when no users are found",
|
||||
status: "warning",
|
||||
duration: 1000,
|
||||
isClosable: true,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const toNextPage = () => {
|
||||
setPageIndex(pageIndex + 1);
|
||||
if (users.length >= 0) {
|
||||
setPagination({
|
||||
cursor: users[users.length - 1].display_name,
|
||||
direction: "forward",
|
||||
});
|
||||
} else {
|
||||
toast({
|
||||
title: "Can not paginate when no users are found",
|
||||
status: "warning",
|
||||
duration: 1000,
|
||||
isClosable: true,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Present users in a naive table.
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import type { NextApiRequest, NextApiResponse } from "next";
|
||||
import { getToken, JWT } from "next-auth/jwt";
|
||||
import { Role } from "src/components/RoleSelect";
|
||||
|
||||
/**
|
||||
* Wraps any API Route handler and verifies that the user does not have the
|
||||
* specified role. Returns a 403 if they do, otherwise runs the handler.
|
||||
*/
|
||||
const withoutRole = (role: string, handler: (arg0: NextApiRequest, arg1: NextApiResponse, arg2: JWT) => void) => {
|
||||
const withoutRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiResponse, arg2: JWT) => void) => {
|
||||
return async (req: NextApiRequest, res: NextApiResponse) => {
|
||||
const token = await getToken({ req });
|
||||
if (!token || token.role === role) {
|
||||
@@ -20,7 +21,7 @@ const withoutRole = (role: string, handler: (arg0: NextApiRequest, arg1: NextApi
|
||||
* Wraps any API Route handler and verifies that the user has the appropriate
|
||||
* role before running the handler. Returns a 403 otherwise.
|
||||
*/
|
||||
const withRole = (role: string, handler: (arg0: NextApiRequest, arg1: NextApiResponse) => void) => {
|
||||
const withRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiResponse) => void) => {
|
||||
return async (req: NextApiRequest, res: NextApiResponse) => {
|
||||
const token = await getToken({ req });
|
||||
if (!token || token.role !== role) {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { JWT } from "next-auth/jwt";
|
||||
import type { Message } from "src/types/Conversation";
|
||||
import { LeaderboardReply, LeaderboardTimeFrame } from "src/types/Leaderboard";
|
||||
import type { BackendUser } from "src/types/Users";
|
||||
|
||||
export class OasstError {
|
||||
@@ -157,10 +158,27 @@ export class OasstApiClient {
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the `max_count` `BackendUser`s stored by the backend.
|
||||
* Returns the set of `BackendUser`s stored by the backend.
|
||||
*
|
||||
* @param {number} max_count - The maximum number of users to fetch.
|
||||
* @param {string} cursor - The user's `display_name` to use when paginating.
|
||||
* @param {boolean} isForward - If true and `cursor` is not empty, pages
|
||||
* forward. If false and `cursor` is not empty, pages backwards.
|
||||
* @returns {Promise<BackendUser[]>} A Promise that returns an array of `BackendUser` objects.
|
||||
*/
|
||||
async fetch_users(max_count: number): Promise<BackendUser[]> {
|
||||
return this.get(`/api/v1/frontend_users/?max_count=${max_count}`);
|
||||
async fetch_users(max_count: number, cursor: string, isForward: boolean): Promise<BackendUser[]> {
|
||||
const params = new URLSearchParams();
|
||||
params.append("max_count", max_count.toString());
|
||||
|
||||
// The backend API uses different query paramters depending on the
|
||||
// pagination direction but they both take the same cursor value.
|
||||
// Depending on direction, pick the right query param.
|
||||
if (cursor !== "") {
|
||||
params.append(isForward ? "gt" : "lt", cursor);
|
||||
}
|
||||
const BASE_URL = `/api/v1/frontend_users`;
|
||||
const url = `${BASE_URL}/?${params.toString()}`;
|
||||
return this.get(url);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -187,8 +205,8 @@ export class OasstApiClient {
|
||||
/**
|
||||
* Returns the current leaderboard ranking.
|
||||
*/
|
||||
async fetch_leaderboard(): Promise<any> {
|
||||
return this.get(`/api/v1/experimental/leaderboards/create/assistant`);
|
||||
async fetch_leaderboard(time_frame: LeaderboardTimeFrame): Promise<LeaderboardReply> {
|
||||
return this.get(`/api/v1/leaderboards/${time_frame}`);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import { SessionProvider } from "next-auth/react";
|
||||
import { FlagsProvider } from "react-feature-flags";
|
||||
import { getDefaultLayout, NextPageWithLayout } from "src/components/Layout";
|
||||
import flags from "src/flags";
|
||||
import { SWRConfig, SWRConfiguration } from "swr";
|
||||
|
||||
import { Chakra, getServerSideProps } from "../styles/Chakra";
|
||||
|
||||
@@ -13,6 +14,11 @@ type AppPropsWithLayout = AppProps & {
|
||||
Component: NextPageWithLayout;
|
||||
};
|
||||
|
||||
const swrConfig: SWRConfiguration = {
|
||||
revalidateOnFocus: false,
|
||||
revalidateOnMount: true,
|
||||
};
|
||||
|
||||
function MyApp({ Component, pageProps: { session, cookies, ...pageProps } }: AppPropsWithLayout) {
|
||||
const getLayout = Component.getLayout ?? getDefaultLayout;
|
||||
const page = getLayout(<Component {...pageProps} />);
|
||||
@@ -20,7 +26,9 @@ function MyApp({ Component, pageProps: { session, cookies, ...pageProps } }: App
|
||||
return (
|
||||
<FlagsProvider value={flags}>
|
||||
<Chakra cookies={cookies}>
|
||||
<SessionProvider session={session}>{page}</SessionProvider>
|
||||
<SWRConfig value={swrConfig}>
|
||||
<SessionProvider session={session}>{page}</SessionProvider>
|
||||
</SWRConfig>
|
||||
</Chakra>
|
||||
</FlagsProvider>
|
||||
);
|
||||
|
||||
@@ -1,17 +1,28 @@
|
||||
import { Button, Container, FormControl, FormLabel, Input, Select, Stack, useToast } from "@chakra-ui/react";
|
||||
import { Field, Form, Formik } from "formik";
|
||||
import { Button, Card, CardBody, Container, FormControl, FormLabel, Input, Stack, useToast } from "@chakra-ui/react";
|
||||
import { InferGetServerSidePropsType } from "next";
|
||||
import Head from "next/head";
|
||||
import { useRouter } from "next/router";
|
||||
import { useSession } from "next-auth/react";
|
||||
import { useEffect } from "react";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { getAdminLayout } from "src/components/Layout";
|
||||
import { Role, RoleSelect } from "src/components/RoleSelect";
|
||||
import { UserMessagesCell } from "src/components/UserMessagesCell";
|
||||
import { post } from "src/lib/api";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import prisma from "src/lib/prismadb";
|
||||
import useSWRMutation from "swr/mutation";
|
||||
|
||||
const ManageUser = ({ user }) => {
|
||||
interface UserForm {
|
||||
user_id: string;
|
||||
id: string;
|
||||
auth_method: string;
|
||||
display_name: string;
|
||||
role: Role;
|
||||
notes: string;
|
||||
}
|
||||
|
||||
const ManageUser = ({ user }: InferGetServerSidePropsType<typeof getServerSideProps>) => {
|
||||
const toast = useToast();
|
||||
const router = useRouter();
|
||||
const { data: session, status } = useSession();
|
||||
@@ -51,6 +62,10 @@ const ManageUser = ({ user }) => {
|
||||
},
|
||||
});
|
||||
|
||||
const { register, handleSubmit } = useForm<UserForm>({
|
||||
defaultValues: user,
|
||||
});
|
||||
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
@@ -61,50 +76,31 @@ const ManageUser = ({ user }) => {
|
||||
/>
|
||||
</Head>
|
||||
<Stack gap="4">
|
||||
<Container className="oa-basic-theme">
|
||||
<Formik
|
||||
initialValues={user}
|
||||
onSubmit={(values) => {
|
||||
trigger(values);
|
||||
}}
|
||||
>
|
||||
<Form>
|
||||
<Field name="user_id" type="hidden" />
|
||||
<Field name="id" type="hidden" />
|
||||
<Field name="auth_method" type="hidden" />
|
||||
<Field name="display_name">
|
||||
{({ field }) => (
|
||||
<FormControl>
|
||||
<FormLabel>Display Name</FormLabel>
|
||||
<Input {...field} isDisabled />
|
||||
</FormControl>
|
||||
)}
|
||||
</Field>
|
||||
<Field name="role">
|
||||
{({ field }) => (
|
||||
<FormControl>
|
||||
<FormLabel>Role</FormLabel>
|
||||
<Select {...field}>
|
||||
<option value="banned">Banned</option>
|
||||
<option value="general">General</option>
|
||||
<option value="admin">Admin</option>
|
||||
</Select>
|
||||
</FormControl>
|
||||
)}
|
||||
</Field>
|
||||
<Field name="notes">
|
||||
{({ field }) => (
|
||||
<FormControl>
|
||||
<FormLabel>Notes</FormLabel>
|
||||
<Input {...field} />
|
||||
</FormControl>
|
||||
)}
|
||||
</Field>
|
||||
<Button mt={4} type="submit">
|
||||
Update
|
||||
</Button>
|
||||
</Form>
|
||||
</Formik>
|
||||
<Container>
|
||||
<Card>
|
||||
<CardBody>
|
||||
<form onSubmit={handleSubmit((data) => trigger(data))}>
|
||||
<input type="hidden" {...register("user_id")}></input>
|
||||
<input type="hidden" {...register("id")}></input>
|
||||
<input type="hidden" {...register("auth_method")}></input>
|
||||
<FormControl>
|
||||
<FormLabel>Display Name</FormLabel>
|
||||
<Input {...register("display_name")} isDisabled />
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>Role</FormLabel>
|
||||
<RoleSelect {...register("role")}></RoleSelect>
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>Notes</FormLabel>
|
||||
<Input {...register("notes")} />
|
||||
</FormControl>
|
||||
<Button mt={4} type="submit">
|
||||
Update
|
||||
</Button>
|
||||
</form>
|
||||
</CardBody>
|
||||
</Card>
|
||||
</Container>
|
||||
<UserMessagesCell path={`/api/admin/user_messages?user=${user.user_id}`} />
|
||||
</Stack>
|
||||
@@ -125,7 +121,7 @@ export async function getServerSideProps({ query }) {
|
||||
});
|
||||
const user = {
|
||||
...backend_user,
|
||||
role: local_user?.role || "general",
|
||||
role: (local_user?.role || "general") as Role,
|
||||
};
|
||||
return {
|
||||
props: {
|
||||
|
||||
@@ -2,17 +2,27 @@ import { withRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
/**
|
||||
* The number of users to fetch in a single request. Could later be a query parameter.
|
||||
*/
|
||||
const PAGE_SIZE = 20;
|
||||
|
||||
/**
|
||||
* Returns a list of user results from the database when the requesting user is
|
||||
* a logged in admin.
|
||||
*
|
||||
* This takes two query params:
|
||||
* - `cursor`: A string representing a user's `display_name`.
|
||||
* - `direction`: Either "forward" or "backward" representing the pagination
|
||||
* direction.
|
||||
*/
|
||||
const handler = withRole("admin", async (req, res) => {
|
||||
// TODO(#673): Update this to support pagination.
|
||||
const { cursor, direction } = req.query;
|
||||
|
||||
// First, get all the users according to the backend.
|
||||
const all_users = await oasstApiClient.fetch_users(20);
|
||||
const all_users = await oasstApiClient.fetch_users(PAGE_SIZE, cursor as string, direction === "forward");
|
||||
|
||||
// Next, get all the users stored in the web's auth datbase to fetch their role.
|
||||
// Next, get all the users stored in the web's auth database to fetch their role.
|
||||
const local_user_ids = all_users.map(({ id }) => id);
|
||||
const local_users = await prisma.user.findMany({
|
||||
where: {
|
||||
|
||||
@@ -2,12 +2,13 @@ import { PrismaAdapter } from "@next-auth/prisma-adapter";
|
||||
import { boolean } from "boolean";
|
||||
import type { AuthOptions } from "next-auth";
|
||||
import NextAuth from "next-auth";
|
||||
import { Provider } from "next-auth/providers";
|
||||
import CredentialsProvider from "next-auth/providers/credentials";
|
||||
import DiscordProvider from "next-auth/providers/discord";
|
||||
import EmailProvider from "next-auth/providers/email";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
const providers = [];
|
||||
const providers: Provider[] = [];
|
||||
|
||||
// Register an email magic link auth method.
|
||||
providers.push(
|
||||
@@ -39,11 +40,13 @@ if (boolean(process.env.DEBUG_LOGIN) || process.env.NODE_ENV === "development")
|
||||
name: "Debug Credentials",
|
||||
credentials: {
|
||||
username: { label: "Username", type: "text" },
|
||||
role: { label: "Role", type: "text" },
|
||||
},
|
||||
async authorize(credentials) {
|
||||
const user = {
|
||||
id: credentials.username,
|
||||
name: credentials.username,
|
||||
role: credentials.role,
|
||||
};
|
||||
// save the user to the database
|
||||
await prisma.user.upsert({
|
||||
|
||||
@@ -1,18 +1,14 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import { oasstApiClient } from "src/lib/oasst_api_client";
|
||||
import { LeaderboardTimeFrame } from "src/types/Leaderboard";
|
||||
|
||||
/**
|
||||
* Returns the set of valid labels that can be applied to messages.
|
||||
*/
|
||||
const handler = withoutRole("banned", async (req, res) => {
|
||||
const { leaderboard } = await oasstApiClient.fetch_leaderboard();
|
||||
res.status(200).json(
|
||||
leaderboard.map(({ display_name, ranking, score }) => ({
|
||||
display_name,
|
||||
ranking,
|
||||
score,
|
||||
}))
|
||||
);
|
||||
const time_frame = (req.query.time_frame as LeaderboardTimeFrame) || LeaderboardTimeFrame.day;
|
||||
const { leaderboard } = await oasstApiClient.fetch_leaderboard(time_frame);
|
||||
res.status(200).json(leaderboard);
|
||||
});
|
||||
|
||||
export default handler;
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
import { Button, Input, Stack } from "@chakra-ui/react";
|
||||
import { Button, ButtonProps, Input, Stack, useColorModeValue } from "@chakra-ui/react";
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { GetServerSideProps } from "next";
|
||||
import Head from "next/head";
|
||||
import Link from "next/link";
|
||||
import { useRouter } from "next/router";
|
||||
import { getCsrfToken, getProviders, signIn } from "next-auth/react";
|
||||
import { ClientSafeProvider, getProviders, signIn } from "next-auth/react";
|
||||
import React, { useEffect, useRef, useState } from "react";
|
||||
import { FaBug, FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
|
||||
import { AuthLayout } from "src/components/AuthLayout";
|
||||
import { Footer } from "src/components/Footer";
|
||||
import { Header } from "src/components/Header";
|
||||
import { RoleSelect } from "src/components/RoleSelect";
|
||||
|
||||
export type SignInErrorTypes =
|
||||
| "Signin"
|
||||
@@ -37,8 +39,11 @@ const errorMessages: Record<SignInErrorTypes, string> = {
|
||||
default: "Unable to sign in.",
|
||||
};
|
||||
|
||||
interface SigninProps {
|
||||
providers: Awaited<ReturnType<typeof getProviders>>;
|
||||
}
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
function Signin({ csrfToken, providers }) {
|
||||
function Signin({ providers }: SigninProps) {
|
||||
const router = useRouter();
|
||||
const { discord, email, github, credentials } = providers;
|
||||
const emailEl = useRef(null);
|
||||
@@ -60,18 +65,10 @@ function Signin({ csrfToken, providers }) {
|
||||
signIn(email.id, { callbackUrl: "/dashboard", email: emailEl.current.value });
|
||||
};
|
||||
|
||||
const debugUsernameEl = useRef(null);
|
||||
function signinWithDebugCredentials(ev: React.FormEvent) {
|
||||
ev.preventDefault();
|
||||
signIn(credentials.id, { callbackUrl: "/dashboard", username: debugUsernameEl.current.value });
|
||||
}
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const bgColorClass = colorMode === "light" ? "bg-gray-50" : "bg-chakra-gray-900";
|
||||
const buttonBgColor = colorMode === "light" ? "#2563eb" : "#2563eb";
|
||||
|
||||
const buttonColorScheme = colorMode === "light" ? "blue" : "dark-blue-btn";
|
||||
|
||||
return (
|
||||
<div className={bgColorClass}>
|
||||
<Head>
|
||||
@@ -80,17 +77,7 @@ function Signin({ csrfToken, providers }) {
|
||||
</Head>
|
||||
<AuthLayout>
|
||||
<Stack spacing="2">
|
||||
{credentials && (
|
||||
<form onSubmit={signinWithDebugCredentials} className="border-2 border-orange-600 rounded-md p-4 relative">
|
||||
<span className={`text-orange-600 absolute -top-3 left-5 ${bgColorClass} px-1`}>For Debugging Only</span>
|
||||
<Stack>
|
||||
<Input variant="outline" size="lg" placeholder="Username" ref={debugUsernameEl} />
|
||||
<Button size={"lg"} leftIcon={<FaBug />} colorScheme={buttonColorScheme} color="white" type="submit">
|
||||
Continue with Debug User
|
||||
</Button>
|
||||
</Stack>
|
||||
</form>
|
||||
)}
|
||||
{credentials && <DebugSigninForm credentials={credentials} bgColorClass={bgColorClass} />}
|
||||
{email && (
|
||||
<form onSubmit={signinWithEmail}>
|
||||
<Stack>
|
||||
@@ -102,16 +89,9 @@ function Signin({ csrfToken, providers }) {
|
||||
placeholder="Email Address"
|
||||
ref={emailEl}
|
||||
/>
|
||||
<Button
|
||||
data-cy="signin-email-button"
|
||||
size={"lg"}
|
||||
leftIcon={<FaEnvelope />}
|
||||
type="submit"
|
||||
colorScheme={buttonColorScheme}
|
||||
color="white"
|
||||
>
|
||||
<SigninButton data-cy="signin-email-button" leftIcon={<FaEnvelope />}>
|
||||
Continue with Email
|
||||
</Button>
|
||||
</SigninButton>
|
||||
</Stack>
|
||||
</form>
|
||||
)}
|
||||
@@ -179,13 +159,49 @@ Signin.getLayout = (page) => (
|
||||
|
||||
export default Signin;
|
||||
|
||||
export async function getServerSideProps() {
|
||||
const csrfToken = await getCsrfToken();
|
||||
const SigninButton = (props: ButtonProps) => {
|
||||
const buttonColorScheme = useColorModeValue("blue", "dark-blue-btn");
|
||||
|
||||
return (
|
||||
<Button
|
||||
size={"lg"}
|
||||
leftIcon={<FaEnvelope />}
|
||||
type="submit"
|
||||
colorScheme={buttonColorScheme}
|
||||
color="white"
|
||||
{...props}
|
||||
></Button>
|
||||
);
|
||||
};
|
||||
|
||||
const DebugSigninForm = ({ credentials, bgColorClass }: { credentials: ClientSafeProvider; bgColorClass: string }) => {
|
||||
const debugUsernameEl = useRef(null);
|
||||
const roleRef = useRef(null);
|
||||
function signinWithDebugCredentials(ev: React.FormEvent) {
|
||||
ev.preventDefault();
|
||||
signIn(credentials.id, {
|
||||
callbackUrl: "/dashboard",
|
||||
username: debugUsernameEl.current.value,
|
||||
role: roleRef.current.value,
|
||||
});
|
||||
}
|
||||
return (
|
||||
<form onSubmit={signinWithDebugCredentials} className="border-2 border-orange-600 rounded-md p-4 relative">
|
||||
<span className={`text-orange-600 absolute -top-3 left-5 ${bgColorClass} px-1`}>For Debugging Only</span>
|
||||
<Stack>
|
||||
<Input variant="outline" size="lg" placeholder="Username" ref={debugUsernameEl} />
|
||||
<RoleSelect defaultValue={"general"} ref={roleRef}></RoleSelect>
|
||||
<SigninButton leftIcon={<FaBug />}>Continue with Debug User</SigninButton>
|
||||
</Stack>
|
||||
</form>
|
||||
);
|
||||
};
|
||||
|
||||
export const getServerSideProps: GetServerSideProps<SigninProps> = async () => {
|
||||
const providers = await getProviders();
|
||||
return {
|
||||
props: {
|
||||
csrfToken,
|
||||
providers,
|
||||
},
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import { Box, Heading } from "@chakra-ui/react";
|
||||
import { Box, Heading, Tabs, TabList, TabPanels, Tab, TabPanel } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { LeaderboardGridCell } from "src/components/LeaderboardGridCell";
|
||||
import { LeaderboardTimeFrame } from "src/types/Leaderboard";
|
||||
|
||||
const Leaderboard = () => {
|
||||
return (
|
||||
@@ -14,7 +15,29 @@ const Leaderboard = () => {
|
||||
<Heading fontSize="2xl" fontWeight="bold" pb="4">
|
||||
Leaderboard
|
||||
</Heading>
|
||||
<LeaderboardGridCell />
|
||||
<Tabs isFitted isLazy>
|
||||
<TabList>
|
||||
<Tab>Daily</Tab>
|
||||
<Tab>Weekly</Tab>
|
||||
<Tab>Monthly</Tab>
|
||||
<Tab>Overall</Tab>
|
||||
</TabList>
|
||||
|
||||
<TabPanels>
|
||||
<TabPanel p="0">
|
||||
<LeaderboardGridCell timeFrame={LeaderboardTimeFrame.day} />
|
||||
</TabPanel>
|
||||
<TabPanel p="0">
|
||||
<LeaderboardGridCell timeFrame={LeaderboardTimeFrame.week} />
|
||||
</TabPanel>
|
||||
<TabPanel p="0">
|
||||
<LeaderboardGridCell timeFrame={LeaderboardTimeFrame.month} />
|
||||
</TabPanel>
|
||||
<TabPanel p="0">
|
||||
<LeaderboardGridCell timeFrame={LeaderboardTimeFrame.total} />
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
</Tabs>
|
||||
</Box>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -1,25 +1,17 @@
|
||||
import { Box, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useState } from "react";
|
||||
import { getDashboardLayout } from "src/components/Layout";
|
||||
import { MessageLoading } from "src/components/Loading/MessageLoading";
|
||||
import { MessageTableEntry } from "src/components/Messages/MessageTableEntry";
|
||||
import { MessageWithChildren } from "src/components/Messages/MessageWithChildren";
|
||||
import { get } from "src/lib/api";
|
||||
import useSWR from "swr";
|
||||
import { Message } from "src/types/Conversation";
|
||||
import useSWRImmutable from "swr/immutable";
|
||||
|
||||
const MessageDetail = ({ id }) => {
|
||||
const MessageDetail = ({ id }: { id: string }) => {
|
||||
const backgroundColor = useColorModeValue("white", "gray.800");
|
||||
const [parent, setParent] = useState(null);
|
||||
|
||||
const { isLoading: isLoadingParent } = useSWR(id ? `/api/messages/${id}/parent` : null, get, {
|
||||
onSuccess: (data) => {
|
||||
setParent(data);
|
||||
},
|
||||
onError: () => {
|
||||
setParent(null);
|
||||
},
|
||||
});
|
||||
const { isLoading: isLoadingParent, data: parent } = useSWRImmutable<Message>(`/api/messages/${id}/parent`, get);
|
||||
|
||||
if (isLoadingParent) {
|
||||
return <MessageLoading />;
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
import { cardAnatomy } from "@chakra-ui/anatomy";
|
||||
import { createMultiStyleConfigHelpers } from "@chakra-ui/react";
|
||||
|
||||
const { definePartsStyle, defineMultiStyleConfig } = createMultiStyleConfigHelpers(cardAnatomy.keys);
|
||||
|
||||
export const cardTheme = defineMultiStyleConfig({
|
||||
baseStyle: definePartsStyle(({ colorMode }) => {
|
||||
const isLightMode = colorMode === "light";
|
||||
return {
|
||||
container: {
|
||||
backgroundColor: isLightMode ? "white" : "gray.700",
|
||||
},
|
||||
header: {},
|
||||
body: {
|
||||
padding: 6,
|
||||
},
|
||||
footer: {},
|
||||
};
|
||||
}),
|
||||
variants: {
|
||||
elevated: definePartsStyle({
|
||||
container: {
|
||||
borderRadius: "xl",
|
||||
},
|
||||
}),
|
||||
},
|
||||
});
|
||||
@@ -2,6 +2,7 @@ import { type ThemeConfig, extendTheme } from "@chakra-ui/react";
|
||||
import { Styles } from "@chakra-ui/theme-tools";
|
||||
|
||||
import { colors } from "./colors";
|
||||
import { cardTheme } from "./components/Card";
|
||||
import { containerTheme } from "./components/Container";
|
||||
|
||||
const config: ThemeConfig = {
|
||||
@@ -12,6 +13,7 @@ const config: ThemeConfig = {
|
||||
|
||||
const components = {
|
||||
Container: containerTheme,
|
||||
Card: cardTheme,
|
||||
};
|
||||
|
||||
const breakpoints = {
|
||||
|
||||
@@ -3,3 +3,40 @@ export interface LeaderboardEntry {
|
||||
ranking: number;
|
||||
score: number;
|
||||
}
|
||||
|
||||
export const enum LeaderboardTimeFrame {
|
||||
day = "day",
|
||||
week = "week",
|
||||
month = "month",
|
||||
total = "total",
|
||||
}
|
||||
export interface LeaderboardReply {
|
||||
time_frame: LeaderboardTimeFrame;
|
||||
leaderboard: LeaderboardEntity[];
|
||||
}
|
||||
|
||||
export interface LeaderboardEntity {
|
||||
rank: number;
|
||||
user_id: string;
|
||||
username: string;
|
||||
auth_method: string;
|
||||
display_name: string;
|
||||
leader_score: number;
|
||||
base_date: string;
|
||||
modified_date: string;
|
||||
prompts: number;
|
||||
replies_assistant: number;
|
||||
replies_prompter: number;
|
||||
labels_simple: number;
|
||||
labels_full: number;
|
||||
rankings_total: number;
|
||||
rankings_good: number;
|
||||
accepted_prompts: number;
|
||||
accepted_replies_assistant: number;
|
||||
accepted_replies_prompter: number;
|
||||
reply_ranked_1: number;
|
||||
reply_ranked_2: number;
|
||||
reply_ranked_3: number;
|
||||
streak_last_day_date: number | null;
|
||||
streak_days: number | null;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
// https://github.com/ts-essentials/ts-essentials/blob/25cae45c162f8784e3cdae8f43783d0c66370a57/lib/types.ts#L437
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
export type ElementOf<T extends readonly any[]> = T extends readonly (infer ET)[] ? ET : never;
|
||||
Reference in New Issue
Block a user