mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge branch 'main' into 766_admin_enhancement
This commit is contained in:
@@ -0,0 +1,17 @@
|
||||
name: Build OASST Postgres image
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- docker/oasst-postgres/**
|
||||
|
||||
jobs:
|
||||
build-postgres:
|
||||
uses: ./.github/workflows/docker-build.yaml
|
||||
with:
|
||||
image-name: oasst-postgres
|
||||
context: ./docker/oasst-postgres
|
||||
dockerfile: docker/oasst-postgres/Dockerfile
|
||||
build-args: ""
|
||||
@@ -34,8 +34,13 @@ jobs:
|
||||
WEB_EMAIL_SERVER_USER: ${{ secrets.DEV_WEB_EMAIL_SERVER_USER }}
|
||||
WEB_NEXTAUTH_SECRET: ${{ secrets.NEXTAUTH_SECRET }}
|
||||
S3_BUCKET_NAME: ${{ secrets.S3_BUCKET_NAME }}
|
||||
S3_REGION: ${{ secrets.S3_REGION }}
|
||||
AWS_ACCESS_KEY: ${{ secrets.AWS_ACCESS_KEY }}
|
||||
AWS_SECRET_KEY: ${{ secrets.AWS_SECRET_KEY }}
|
||||
MAX_ACTIVE_TREES: ${{ vars.MAX_ACTIVE_TREES }}
|
||||
MAX_TREE_DEPTH: ${{ vars.MAX_TREE_DEPTH }}
|
||||
GOAL_TREE_SIZE: ${{ vars.GOAL_TREE_SIZE }}
|
||||
SKIP_TOXICITY_CALCULATION: ${{ vars.SKIP_TOXICITY_CALCULATION }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
|
||||
+20
-25
@@ -57,8 +57,9 @@
|
||||
- name: Create postgres containers
|
||||
community.docker.docker_container:
|
||||
name: "oasst-{{ stack_name }}-postgres-{{ item.name }}"
|
||||
image: postgres:15
|
||||
image: ghcr.io/laion-ai/open-assistant/oasst-postgres
|
||||
state: started
|
||||
pull: true
|
||||
recreate: "{{ (stack_name == 'dev') | bool }}"
|
||||
restart_policy: always
|
||||
network_mode: "oasst-{{ stack_name }}"
|
||||
@@ -66,6 +67,13 @@
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: "{{ postgres_password }}"
|
||||
POSTGRES_DB: postgres
|
||||
S3_BUCKET_NAME:
|
||||
"{{ lookup('ansible.builtin.env', 'S3_BUCKET_NAME') }}"
|
||||
AWS_ACCESS_KEY_ID:
|
||||
"{{ lookup('ansible.builtin.env', 'AWS_ACCESS_KEY') }}"
|
||||
AWS_SECRET_ACCESS_KEY:
|
||||
"{{ lookup('ansible.builtin.env', 'AWS_SECRET_KEY') }}"
|
||||
AWS_DEFAULT_REGION: "{{ lookup('ansible.builtin.env', 'S3_REGION') }}"
|
||||
volumes:
|
||||
- "oasst-{{ stack_name }}-postgres-{{ item.name
|
||||
}}:/var/lib/postgresql/data"
|
||||
@@ -78,29 +86,6 @@
|
||||
- name: backend
|
||||
- name: web
|
||||
|
||||
- name: Copy pgbackrest.conf to managed node
|
||||
ansible.builtin.copy:
|
||||
src: ./pgbackrest.conf
|
||||
dest: "./{{ stack_name }}/pgbackrest.conf"
|
||||
mode: 0644
|
||||
|
||||
- name: Create pgbackrest container
|
||||
community.docker.docker_container:
|
||||
name: "oasst-{{ stack_name }}-pgbackrest"
|
||||
image: woblerr/pgbackrest:2.43
|
||||
state: "{{ 'stopped' if stack_name == 'production' else 'absent' }}"
|
||||
network_mode: "oasst-{{ stack_name }}"
|
||||
volumes:
|
||||
- "./{{ stack_name }}/pgbackrest.conf:/etc/pgbackrest/pgbackrest.conf"
|
||||
- "oasst-{{ stack_name }}-postgres-backend:/var/lib/postgresql/data"
|
||||
env:
|
||||
PGBACKREST_REPO1_S3_BUCKET:
|
||||
"{{ lookup('ansible.builtin.env', 'S3_BUCKET_NAME') }}"
|
||||
PGBACKREST_REPO1_S3_KEY:
|
||||
"{{ lookup('ansible.builtin.env', 'AWS_ACCESS_KEY') }}"
|
||||
PGBACKREST_REPO1_S3_KEY_SECRET:
|
||||
"{{ lookup('ansible.builtin.env', 'AWS_SECRET_KEY') }}"
|
||||
|
||||
- name: Run the oasst oasst-backend
|
||||
community.docker.docker_container:
|
||||
name: "oasst-{{ stack_name }}-backend"
|
||||
@@ -122,8 +107,18 @@
|
||||
RATE_LIMIT: "{{ 'false' if stack_name == 'dev' else 'true' }}"
|
||||
DEBUG_SKIP_EMBEDDING_COMPUTATION: "true"
|
||||
DEBUG_SKIP_TOXICITY_CALCULATION:
|
||||
"{{ 'true' if stack_name == 'dev' else 'false' }}"
|
||||
"{{ lookup('ansible.builtin.env', 'SKIP_TOXICITY_CALCULATION') |
|
||||
default('true', true) }}"
|
||||
OFFICIAL_WEB_API_KEY: "{{ web_api_key }}"
|
||||
TREE_MANAGER__MAX_ACTIVE_TREES:
|
||||
"{{ lookup('ansible.builtin.env', 'MAX_ACTIVE_TREES') |
|
||||
default('10', true) }}"
|
||||
TREE_MANAGER__MAX_TREE_DEPTH:
|
||||
"{{ lookup('ansible.builtin.env', 'MAX_TREE_DEPTH') | default('5',
|
||||
true) }}"
|
||||
TREE_MANAGER__GOAL_TREE_SIZE:
|
||||
"{{ lookup('ansible.builtin.env', 'GOAL_TREE_SIZE') | default('15',
|
||||
true) }}"
|
||||
ports:
|
||||
- "{{ backend_port }}:8080"
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
pg1-path=/var/lib/postgresql/data
|
||||
|
||||
[global]
|
||||
repo1-retention-full=3 # keep last 3 backups
|
||||
repo1-retention-full=3
|
||||
repo1-type=s3
|
||||
repo1-path=/oasst-prod
|
||||
repo1-s3-region=us-east-1
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
"""switch to timestamp with tz
|
||||
|
||||
Revision ID: 7f0a28a156f4
|
||||
Revises: 0964ac95170d
|
||||
Create Date: 2023-01-19 21:53:01.107137
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7f0a28a156f4"
|
||||
down_revision = "0964ac95170d"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column(table_name="user_stats", column_name="modified_date", type_=sa.DateTime(timezone=True))
|
||||
op.alter_column(table_name="user_stats", column_name="base_date", type_=sa.DateTime(timezone=True))
|
||||
op.alter_column(table_name="journal_integration", column_name="last_run", type_=sa.DateTime(timezone=True))
|
||||
op.alter_column(table_name="message_embedding", column_name="created_date", type_=sa.DateTime(timezone=True))
|
||||
op.alter_column(table_name="message_reaction", column_name="created_date", type_=sa.DateTime(timezone=True))
|
||||
op.alter_column(table_name="message_toxicity", column_name="created_date", type_=sa.DateTime(timezone=True))
|
||||
op.alter_column(table_name="message", column_name="created_date", type_=sa.DateTime(timezone=True))
|
||||
op.alter_column(table_name="task", column_name="created_date", type_=sa.DateTime(timezone=True))
|
||||
op.alter_column(table_name="task", column_name="expiry_date", type_=sa.DateTime(timezone=True))
|
||||
op.alter_column(table_name="text_labels", column_name="created_date", type_=sa.DateTime(timezone=True))
|
||||
op.alter_column(table_name="user", column_name="created_date", type_=sa.DateTime(timezone=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column(table_name="user_stats", column_name="modified_date", type_=sa.DateTime(timezone=False))
|
||||
op.alter_column(table_name="user_stats", column_name="base_date", type_=sa.DateTime(timezone=False))
|
||||
op.alter_column(table_name="journal_integration", column_name="last_run", type_=sa.DateTime(timezone=False))
|
||||
op.alter_column(table_name="message_embedding", column_name="created_date", type_=sa.DateTime(timezone=False))
|
||||
op.alter_column(table_name="message_reaction", column_name="created_date", type_=sa.DateTime(timezone=False))
|
||||
op.alter_column(table_name="message_toxicity", column_name="created_date", type_=sa.DateTime(timezone=False))
|
||||
op.alter_column(table_name="message", column_name="created_date", type_=sa.DateTime(timezone=False))
|
||||
op.alter_column(table_name="task", column_name="created_date", type_=sa.DateTime(timezone=False))
|
||||
op.alter_column(table_name="task", column_name="expiry_date", type_=sa.DateTime(timezone=False))
|
||||
op.alter_column(table_name="text_labels", column_name="created_date", type_=sa.DateTime(timezone=False))
|
||||
op.alter_column(table_name="user", column_name="created_date", type_=sa.DateTime(timezone=False))
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,7 +1,17 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import pydantic
|
||||
from fastapi import APIRouter, Depends
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.config import Settings, settings
|
||||
from oasst_backend.models import ApiClient, User
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.tree_manager import TreeManager
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
|
||||
from oasst_shared.schemas.protocol import SystemStats
|
||||
from oasst_shared.utils import ScopeTimer, unaware_to_utc
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -13,7 +23,7 @@ class CreateApiClientRequest(pydantic.BaseModel):
|
||||
admin_email: str | None = None
|
||||
|
||||
|
||||
@router.post("/api_client")
|
||||
@router.post("/api_client", response_model=str)
|
||||
async def create_api_client(
|
||||
request: CreateApiClientRequest,
|
||||
root_token: str = Depends(deps.get_root_token),
|
||||
@@ -29,3 +39,125 @@ async def create_api_client(
|
||||
)
|
||||
logger.info(f"Created api_client with key {api_client.api_key}")
|
||||
return api_client.api_key
|
||||
|
||||
|
||||
@router.get("/backend_settings/full", response_model=Settings)
|
||||
async def get_backend_settings_full(api_client: ApiClient = Depends(deps.get_trusted_api_client)) -> Settings:
|
||||
logger.info(
|
||||
f"Backend settings requested by trusted api_client {api_client.id} (admin_email: {api_client.admin_email}, frontend_type: {api_client.frontend_type})"
|
||||
)
|
||||
return settings
|
||||
|
||||
|
||||
class PublicSettings(pydantic.BaseModel):
|
||||
"""Subset of backend settings which can be retrieved by untrusted API clients."""
|
||||
|
||||
PROJECT_NAME: str
|
||||
API_V1_STR: str
|
||||
DEBUG_USE_SEED_DATA: bool
|
||||
DEBUG_ALLOW_SELF_LABELING: bool
|
||||
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool
|
||||
DEBUG_SKIP_TOXICITY_CALCULATION: bool
|
||||
DEBUG_DATABASE_ECHO: bool
|
||||
USER_STATS_INTERVAL_DAY: int
|
||||
USER_STATS_INTERVAL_WEEK: int
|
||||
USER_STATS_INTERVAL_MONTH: int
|
||||
USER_STATS_INTERVAL_TOTAL: int
|
||||
|
||||
|
||||
@router.get("/backend_settings/public", response_model=PublicSettings)
|
||||
async def get_backend_settings_public(api_client: ApiClient = Depends(deps.get_api_client)) -> PublicSettings:
|
||||
return PublicSettings(**settings.dict())
|
||||
|
||||
|
||||
class PurgeResultModel(pydantic.BaseModel):
|
||||
before: SystemStats
|
||||
after: SystemStats
|
||||
preview: bool
|
||||
duration: float
|
||||
|
||||
|
||||
@router.post("/purge_user/{user_id}", response_model=PurgeResultModel)
|
||||
async def purge_user(
|
||||
user_id: UUID,
|
||||
preview: bool = True,
|
||||
ban: bool = True,
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
) -> str:
|
||||
assert api_client.trusted
|
||||
|
||||
@managed_tx_function(CommitMode.ROLLBACK if preview else CommitMode.COMMIT)
|
||||
def purge_tx(session: deps.Session) -> tuple[User, SystemStats, SystemStats]:
|
||||
pr = PromptRepository(session, api_client)
|
||||
|
||||
stats_before = pr.get_stats()
|
||||
|
||||
user = pr.user_repository.get_user(user_id)
|
||||
tm = TreeManager(session, pr)
|
||||
tm.purge_user(user_id=user_id, ban=ban)
|
||||
|
||||
session.expunge(user)
|
||||
return user, stats_before, pr.get_stats()
|
||||
|
||||
timer = ScopeTimer()
|
||||
user, before, after = purge_tx()
|
||||
timer.stop()
|
||||
|
||||
if preview:
|
||||
logger.info(
|
||||
f"PURGE USER PREVIEW: '{user.display_name}' (id: {str(user_id)}; username: '{user.username}'; auth-method: '{user.auth_method}')"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"PURGE USER: '{user.display_name}' (id: {str(user_id)}; username: '{user.username}'; auth-method: '{user.auth_method}')"
|
||||
)
|
||||
|
||||
logger.info(f"{before=}; {after=}")
|
||||
return PurgeResultModel(before=before, after=after, preview=preview, duration=timer.elapsed)
|
||||
|
||||
|
||||
@router.post("/purge_user/{user_id}/messages", response_model=PurgeResultModel)
|
||||
async def purge_user_messages(
|
||||
user_id: UUID,
|
||||
purge_initial_prompts: bool = False,
|
||||
min_date: datetime = None,
|
||||
max_date: datetime = None,
|
||||
preview: bool = True,
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
) -> str:
|
||||
assert api_client.trusted
|
||||
|
||||
min_date = unaware_to_utc(min_date)
|
||||
max_date = unaware_to_utc(max_date)
|
||||
|
||||
@managed_tx_function(CommitMode.ROLLBACK if preview else CommitMode.COMMIT)
|
||||
def purge_user_messages_tx(session: deps.Session):
|
||||
pr = PromptRepository(session, api_client)
|
||||
|
||||
stats_before = pr.get_stats()
|
||||
|
||||
user = pr.user_repository.get_user(user_id)
|
||||
|
||||
tm = TreeManager(session, pr)
|
||||
tm.purge_user_messages(
|
||||
user_id, purge_initial_prompts=purge_initial_prompts, min_date=min_date, max_date=max_date
|
||||
)
|
||||
|
||||
session.expunge(user)
|
||||
return user, stats_before, pr.get_stats()
|
||||
|
||||
timer = ScopeTimer()
|
||||
user, before, after = purge_user_messages_tx()
|
||||
timer.stop()
|
||||
|
||||
if preview:
|
||||
logger.info(
|
||||
f"PURGE USER MESSAGES PREVIEW: '{user.display_name}' (id: {str(user_id)}; username: '{user.username}'; auth-method: '{user.auth_method}')"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"PURGE USER MESSAGES: '{user.display_name}' (id: {str(user_id)}; username: '{user.username}'; auth-method: '{user.auth_method}')"
|
||||
)
|
||||
|
||||
logger.info(f"{before=}; {after=}")
|
||||
return PurgeResultModel(before=before, after=after, preview=preview, duration=timer.elapsed)
|
||||
|
||||
@@ -45,7 +45,7 @@ def get_tree_by_frontend_id(
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
tree = pr.fetch_message_tree(message.message_tree_id)
|
||||
tree = pr.fetch_message_tree(message.message_tree_id, reviewed=False)
|
||||
return utils.prepare_tree(tree, message.message_tree_id)
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from oasst_backend.models import ApiClient
|
||||
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -19,3 +20,22 @@ def get_leaderboard(
|
||||
) -> LeaderboardStats:
|
||||
usr = UserStatsRepository(db)
|
||||
return usr.get_leaderboard(time_frame, limit=max_count)
|
||||
|
||||
|
||||
@router.post("/update/{time_frame}", response_model=None, status_code=HTTP_204_NO_CONTENT)
|
||||
def update_leaderboard_time_frame(
|
||||
time_frame: UserStatsTimeFrame,
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
) -> LeaderboardStats:
|
||||
usr = UserStatsRepository(db)
|
||||
return usr.update_stats(time_frame=time_frame)
|
||||
|
||||
|
||||
@router.post("/update", response_model=None, status_code=HTTP_204_NO_CONTENT)
|
||||
def update_leaderboards_all(
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
) -> LeaderboardStats:
|
||||
usr = UserStatsRepository(db)
|
||||
return usr.update_all_time_frames()
|
||||
|
||||
@@ -7,6 +7,7 @@ from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from oasst_shared.utils import unaware_to_utc
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
|
||||
@@ -29,6 +30,9 @@ def query_messages(
|
||||
"""
|
||||
Query messages.
|
||||
"""
|
||||
start_date = unaware_to_utc(start_date)
|
||||
end_date = unaware_to_utc(end_date)
|
||||
|
||||
pr = PromptRepository(db, api_client)
|
||||
messages = pr.query_messages(
|
||||
username=username,
|
||||
@@ -78,7 +82,7 @@ def get_tree(
|
||||
"""
|
||||
pr = PromptRepository(db, api_client)
|
||||
message = pr.fetch_message(message_id)
|
||||
tree = pr.fetch_message_tree(message.message_tree_id)
|
||||
tree = pr.fetch_message_tree(message.message_tree_id, reviewed=False)
|
||||
return utils.prepare_tree(tree, message.message_tree_id)
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from fastapi import APIRouter, Depends
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.tree_manager import TreeManager, TreeManagerStats, TreeMessageCountStats
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
|
||||
@@ -15,3 +16,34 @@ def get_message_stats(
|
||||
):
|
||||
pr = PromptRepository(db, api_client)
|
||||
return pr.get_stats()
|
||||
|
||||
|
||||
@router.get("/tree_manager/state_counts", response_model=dict[str, int])
|
||||
def get_tree_manager__state_counts(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
pr = PromptRepository(db, api_client)
|
||||
tm = TreeManager(db, pr)
|
||||
return tm.tree_counts_by_state()
|
||||
|
||||
|
||||
@router.get("/tree_manager/message_counts", response_model=list[TreeMessageCountStats])
|
||||
def get_tree_manager__message_counts(
|
||||
only_active: bool = True,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
pr = PromptRepository(db, api_client)
|
||||
tm = TreeManager(db, pr)
|
||||
return tm.tree_message_count_stats(only_active=only_active)
|
||||
|
||||
|
||||
@router.get("/tree_manager", response_model=TreeManagerStats)
|
||||
def get_tree_manager__stats(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
pr = PromptRepository(db, api_client)
|
||||
tm = TreeManager(db, pr)
|
||||
return tm.stats()
|
||||
|
||||
@@ -36,6 +36,8 @@ def request_task(
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, client_user=request.user)
|
||||
pr.ensure_user_is_enabled()
|
||||
|
||||
tm = TreeManager(db, pr)
|
||||
task, message_tree_id, parent_message_id = tm.next_task(request.type)
|
||||
pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective)
|
||||
|
||||
@@ -50,6 +50,6 @@ class JournalIntegration(SQLModel, table=True):
|
||||
)
|
||||
description: str = Field(max_length=512, primary_key=True)
|
||||
last_journal_id: Optional[UUID] = Field(foreign_key="journal.id", nullable=True)
|
||||
last_run: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
|
||||
last_run: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
|
||||
last_error: Optional[str] = Field(nullable=True)
|
||||
next_run: Optional[datetime] = Field(nullable=True)
|
||||
|
||||
@@ -30,7 +30,9 @@ class Message(SQLModel, table=True):
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
frontend_message_id: str = Field(max_length=200, nullable=False)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True)
|
||||
sa_column=sa.Column(
|
||||
sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp(), index=True
|
||||
)
|
||||
)
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: Optional[PayloadContainer] = Field(
|
||||
|
||||
@@ -17,5 +17,5 @@ class MessageEmbedding(SQLModel, table=True):
|
||||
|
||||
# In the case that the Message Embedding is created afterwards
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
|
||||
@@ -19,7 +19,9 @@ class MessageReaction(SQLModel, table=True):
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False, primary_key=True)
|
||||
)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True)
|
||||
sa_column=sa.Column(
|
||||
sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp(), index=True
|
||||
)
|
||||
)
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
|
||||
|
||||
@@ -20,5 +20,5 @@ class MessageToxicity(SQLModel, table=True):
|
||||
|
||||
# In the case that the Message Embedding is created afterwards
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
|
||||
@@ -20,9 +20,9 @@ class Task(SQLModel, table=True):
|
||||
),
|
||||
)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
)
|
||||
expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
|
||||
expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
|
||||
user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
|
||||
|
||||
@@ -17,7 +17,9 @@ class TextLabels(SQLModel, table=True):
|
||||
)
|
||||
user_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False))
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True),
|
||||
sa_column=sa.Column(
|
||||
sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp(), index=True
|
||||
),
|
||||
)
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
text: str = Field(nullable=False, max_length=2**16)
|
||||
|
||||
@@ -21,7 +21,7 @@ class User(SQLModel, table=True):
|
||||
auth_method: str = Field(nullable=False, max_length=128, default="local")
|
||||
display_name: str = Field(nullable=False, max_length=256)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
api_client_id: UUID = Field(foreign_key="api_client.id")
|
||||
enabled: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.true()))
|
||||
|
||||
@@ -26,11 +26,11 @@ class UserStats(SQLModel, table=True):
|
||||
user_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), primary_key=True)
|
||||
)
|
||||
base_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
|
||||
base_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
|
||||
|
||||
leader_score: int = 0
|
||||
modified_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
|
||||
rank: int = Field(nullable=True)
|
||||
|
||||
@@ -28,8 +28,7 @@ from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import SystemStats
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import Session, func
|
||||
from sqlmodel import Session, func, not_, text, update
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
@@ -53,6 +52,13 @@ class PromptRepository:
|
||||
)
|
||||
self.journal = JournalWriter(db, api_client, self.user)
|
||||
|
||||
def ensure_user_is_enabled(self):
|
||||
if self.user is None or self.user_id is None:
|
||||
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
|
||||
|
||||
if self.user.deleted or not self.user.enabled:
|
||||
raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED)
|
||||
|
||||
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
message: Message = (
|
||||
@@ -146,6 +152,8 @@ class PromptRepository:
|
||||
review_result: bool = False,
|
||||
check_tree_state: bool = True,
|
||||
) -> Message:
|
||||
self.ensure_user_is_enabled()
|
||||
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
validate_frontend_message_id(user_frontend_message_id)
|
||||
|
||||
@@ -354,8 +362,7 @@ class PromptRepository:
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
|
||||
if self.user_id is None:
|
||||
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
|
||||
self.ensure_user_is_enabled()
|
||||
|
||||
container = PayloadContainer(payload=payload)
|
||||
reaction = MessageReaction(
|
||||
@@ -499,10 +506,14 @@ class PromptRepository:
|
||||
messages = self.db.query(Message).filter(Message.parent_id.is_(None)).order_by(func.random()).limit(size).all()
|
||||
return messages
|
||||
|
||||
def fetch_message_tree(self, message_tree_id: UUID, reviewed: bool = True):
|
||||
def fetch_message_tree(
|
||||
self, message_tree_id: UUID, reviewed: bool = True, include_deleted: bool = False
|
||||
) -> list[Message]:
|
||||
qry = self.db.query(Message).filter(Message.message_tree_id == message_tree_id)
|
||||
if reviewed:
|
||||
qry = qry.filter(Message.review_result)
|
||||
if not include_deleted:
|
||||
qry = qry.filter(not_(Message.deleted))
|
||||
return qry.all()
|
||||
|
||||
def fetch_multiple_random_replies(self, max_size: int = 5, message_role: str = None):
|
||||
@@ -702,6 +713,21 @@ class PromptRepository:
|
||||
|
||||
return messages.all()
|
||||
|
||||
def update_children_counts(self, message_tree_id: UUID):
|
||||
sql_update_children_count = """
|
||||
UPDATE message SET children_count = cc.children_count
|
||||
FROM (
|
||||
SELECT m.id, count(c.id) - COALESCE(SUM(c.deleted::int), 0) AS children_count
|
||||
FROM message m
|
||||
LEFT JOIN message c ON m.id = c.parent_id
|
||||
WHERE m.message_tree_id = :message_tree_id
|
||||
GROUP BY m.id
|
||||
) AS cc
|
||||
WHERE message.id = cc.id;
|
||||
"""
|
||||
r = self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id})
|
||||
logger.debug(f"update_children_count({message_tree_id=}): {r.rowcount} rows.")
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True):
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import random
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
@@ -9,14 +10,14 @@ import pydantic
|
||||
from loguru import logger
|
||||
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
|
||||
from oasst_backend.config import TreeManagerConfiguration, settings
|
||||
from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, message_tree_state
|
||||
from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, User, message_tree_state
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
|
||||
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
|
||||
from oasst_backend.utils.ranking import ranked_pairs
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlalchemy.sql import text
|
||||
from sqlmodel import Session, func, not_
|
||||
from sqlmodel import Session, func, not_, text, update
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
@@ -68,6 +69,25 @@ class IncompleteRankingsRow(pydantic.BaseModel):
|
||||
orm_mode = True
|
||||
|
||||
|
||||
class TreeMessageCountStats(pydantic.BaseModel):
|
||||
message_tree_id: UUID
|
||||
state: str
|
||||
depth: int
|
||||
oldest: datetime
|
||||
youngest: datetime
|
||||
count: int
|
||||
goal_tree_size: int
|
||||
|
||||
@property
|
||||
def completed(self) -> int:
|
||||
return self.count / self.goal_tree_size
|
||||
|
||||
|
||||
class TreeManagerStats(pydantic.BaseModel):
|
||||
state_counts: dict[str, int]
|
||||
message_counts: list[TreeMessageCountStats]
|
||||
|
||||
|
||||
class TreeManager:
|
||||
_all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel))
|
||||
|
||||
@@ -130,7 +150,7 @@ class TreeManager:
|
||||
def _determine_task_availability_internal(
|
||||
self,
|
||||
num_active_trees: int,
|
||||
extensible_parents: list[ExtendibleParentRow],
|
||||
extendible_parents: list[ExtendibleParentRow],
|
||||
prompts_need_review: list[Message],
|
||||
replies_need_review: list[Message],
|
||||
incomplete_rankings: list[IncompleteRankingsRow],
|
||||
@@ -141,17 +161,17 @@ class TreeManager:
|
||||
task_count_by_type[protocol_schema.TaskRequestType.initial_prompt] = num_missing_prompts
|
||||
|
||||
task_count_by_type[protocol_schema.TaskRequestType.prompter_reply] = len(
|
||||
list(filter(lambda x: x.parent_role == "assistant", extensible_parents))
|
||||
list(filter(lambda x: x.parent_role == "assistant", extendible_parents))
|
||||
)
|
||||
task_count_by_type[protocol_schema.TaskRequestType.assistant_reply] = len(
|
||||
list(filter(lambda x: x.parent_role == "prompter", extensible_parents))
|
||||
list(filter(lambda x: x.parent_role == "prompter", extendible_parents))
|
||||
)
|
||||
|
||||
task_count_by_type[protocol_schema.TaskRequestType.label_initial_prompt] = len(prompts_need_review)
|
||||
task_count_by_type[protocol_schema.TaskRequestType.label_assistant_reply] = len(
|
||||
list(filter(lambda m: m.role == "assistant", replies_need_review))
|
||||
)
|
||||
task_count_by_type[protocol_schema.TaskRequestType.prompter_reply] = len(
|
||||
task_count_by_type[protocol_schema.TaskRequestType.label_prompter_reply] = len(
|
||||
list(filter(lambda m: m.role == "prompter", replies_need_review))
|
||||
)
|
||||
|
||||
@@ -171,15 +191,17 @@ class TreeManager:
|
||||
return task_count_by_type
|
||||
|
||||
def determine_task_availability(self) -> dict[protocol_schema.TaskRequestType, int]:
|
||||
self.pr.ensure_user_is_enabled()
|
||||
|
||||
num_active_trees = self.query_num_active_trees()
|
||||
extensible_parents = self.query_extendible_parents()
|
||||
extendible_parents = self.query_extendible_parents()
|
||||
prompts_need_review = self.query_prompts_need_review()
|
||||
replies_need_review = self.query_replies_need_review()
|
||||
incomplete_rankings = self.query_incomplete_rankings()
|
||||
|
||||
return self._determine_task_availability_internal(
|
||||
num_active_trees=num_active_trees,
|
||||
extensible_parents=extensible_parents,
|
||||
extendible_parents=extendible_parents,
|
||||
prompts_need_review=prompts_need_review,
|
||||
replies_need_review=replies_need_review,
|
||||
incomplete_rankings=incomplete_rankings,
|
||||
@@ -191,10 +213,12 @@ class TreeManager:
|
||||
|
||||
logger.debug("TreeManager.next_task()")
|
||||
|
||||
self.pr.ensure_user_is_enabled()
|
||||
|
||||
num_active_trees = self.query_num_active_trees()
|
||||
prompts_need_review = self.query_prompts_need_review()
|
||||
replies_need_review = self.query_replies_need_review()
|
||||
extensible_parents = self.query_extendible_parents()
|
||||
extendible_parents = self.query_extendible_parents()
|
||||
|
||||
incomplete_rankings = self.query_incomplete_rankings()
|
||||
if not self.cfg.rank_prompter_replies:
|
||||
@@ -224,7 +248,7 @@ class TreeManager:
|
||||
else:
|
||||
task_count_by_type = self._determine_task_availability_internal(
|
||||
num_active_trees=num_active_trees,
|
||||
extensible_parents=extensible_parents,
|
||||
extendible_parents=extendible_parents,
|
||||
prompts_need_review=prompts_need_review,
|
||||
replies_need_review=replies_need_review,
|
||||
incomplete_rankings=incomplete_rankings,
|
||||
@@ -266,7 +290,7 @@ class TreeManager:
|
||||
ranking_parent_id = random.choice(incomplete_rankings).parent_id
|
||||
|
||||
messages = self.pr.fetch_message_conversation(ranking_parent_id)
|
||||
assert len(messages) > 1 and messages[-1].id == ranking_parent_id
|
||||
assert len(messages) > 0 and messages[-1].id == ranking_parent_id
|
||||
ranking_parent = messages[-1]
|
||||
assert not ranking_parent.deleted and ranking_parent.review_result
|
||||
conversation = prepare_conversation(messages)
|
||||
@@ -356,12 +380,12 @@ class TreeManager:
|
||||
case TaskType.REPLY:
|
||||
# select a tree with missing replies
|
||||
if task_role == TaskRole.PROMPTER:
|
||||
extensible_parents = list(filter(lambda x: x.parent_role == "assistant", extensible_parents))
|
||||
extendible_parents = list(filter(lambda x: x.parent_role == "assistant", extendible_parents))
|
||||
elif task_role == TaskRole.ASSISTANT:
|
||||
extensible_parents = list(filter(lambda x: x.parent_role == "prompter", extensible_parents))
|
||||
extendible_parents = list(filter(lambda x: x.parent_role == "prompter", extendible_parents))
|
||||
|
||||
if len(extensible_parents) > 0:
|
||||
random_parent = random.choice(extensible_parents)
|
||||
if len(extendible_parents) > 0:
|
||||
random_parent = random.choice(extendible_parents)
|
||||
|
||||
# fetch random conversation to extend
|
||||
logger.debug(f"selected {random_parent=}")
|
||||
@@ -424,6 +448,7 @@ class TreeManager:
|
||||
@async_managed_tx_method(CommitMode.COMMIT)
|
||||
async def handle_interaction(self, interaction: protocol_schema.AnyInteraction) -> protocol_schema.Task:
|
||||
pr = self.pr
|
||||
pr.ensure_user_is_enabled()
|
||||
match type(interaction):
|
||||
case protocol_schema.TextReplyToMessage:
|
||||
logger.info(
|
||||
@@ -488,7 +513,8 @@ class TreeManager:
|
||||
|
||||
_, task = pr.store_ranking(interaction)
|
||||
|
||||
self.check_condition_for_scoring_state(task.message_tree_id)
|
||||
ok, rankings_by_message = self.check_condition_for_scoring_state(task.message_tree_id)
|
||||
self.update_message_ranks(task.message_tree_id, rankings_by_message)
|
||||
|
||||
case protocol_schema.TextLabels:
|
||||
logger.info(
|
||||
@@ -551,7 +577,6 @@ class TreeManager:
|
||||
mts = self.pr.fetch_tree_state(message_tree_id)
|
||||
self._enter_state(mts, message_tree_state.State.ABORTED_LOW_GRADE)
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def check_condition_for_growing_state(self, message_tree_id: UUID) -> bool:
|
||||
logger.debug(f"check_condition_for_growing_state({message_tree_id=})")
|
||||
|
||||
@@ -569,7 +594,6 @@ class TreeManager:
|
||||
self._enter_state(mts, message_tree_state.State.GROWING)
|
||||
return True
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def check_condition_for_ranking_state(self, message_tree_id: UUID) -> bool:
|
||||
logger.debug(f"check_condition_for_ranking_state({message_tree_id=})")
|
||||
|
||||
@@ -587,22 +611,54 @@ class TreeManager:
|
||||
self._enter_state(mts, message_tree_state.State.RANKING)
|
||||
return True
|
||||
|
||||
def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool:
|
||||
def check_condition_for_scoring_state(
|
||||
self, message_tree_id: UUID
|
||||
) -> Tuple[bool, dict[UUID, list[MessageReaction]]]:
|
||||
logger.debug(f"check_condition_for_scoring_state({message_tree_id=})")
|
||||
mts: MessageTreeState
|
||||
mts = self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one()
|
||||
|
||||
mts = self.pr.fetch_tree_state(message_tree_id)
|
||||
if not mts.active or mts.state != message_tree_state.State.RANKING:
|
||||
logger.debug(f"False {mts.active=}, {mts.state=}")
|
||||
return False
|
||||
return False, None
|
||||
|
||||
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
|
||||
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
|
||||
for parent_msg_id, ranking in rankings_by_message.items():
|
||||
if len(ranking) < self.cfg.num_required_rankings:
|
||||
logger.debug(f"False {parent_msg_id=} {len(ranking)=}")
|
||||
return False
|
||||
return False, None
|
||||
|
||||
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
|
||||
return True, rankings_by_message
|
||||
|
||||
def update_message_ranks(self, message_tree_id: UUID, rankings_by_message: Dict[int, int]) -> bool:
|
||||
|
||||
mts = self.pr.fetch_tree_state(message_tree_id)
|
||||
# check state, allow retry if in SCORING_FAILED state
|
||||
if mts.state not in (message_tree_state.State.READY_FOR_SCORING, message_tree_state.State.SCORING_FAILED):
|
||||
logger.debug(f"False {mts.active=}, {mts.state=}")
|
||||
return False
|
||||
|
||||
try:
|
||||
for rankings in rankings_by_message.values():
|
||||
sorted_messages = []
|
||||
for msg_reaction in rankings:
|
||||
sorted_messages.append(msg_reaction.payload.payload.ranked_message_ids)
|
||||
logger.debug(f"SORTED MESSAGE {sorted_messages}")
|
||||
consensus = ranked_pairs(sorted_messages)
|
||||
logger.debug(f"CONSENSUS: {consensus}\n\n")
|
||||
for rank, message_id in enumerate(consensus):
|
||||
# set rank for each message_id for Message rows
|
||||
msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True)
|
||||
msg.rank = rank
|
||||
self.db.add(msg)
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"update_message_ranks({message_tree_id=}) failed")
|
||||
self._enter_state(mts, message_tree_state.State.SCORING_FAILED)
|
||||
return False
|
||||
|
||||
self._enter_state(mts, message_tree_state.State.READY_FOR_EXPORT)
|
||||
return True
|
||||
|
||||
def _calculate_acceptance(self, labels: list[TextLabels]):
|
||||
@@ -618,7 +674,7 @@ class TreeManager:
|
||||
qry = (
|
||||
self.db.query(Message)
|
||||
.select_from(MessageTreeState)
|
||||
.outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
|
||||
.filter(
|
||||
MessageTreeState.active,
|
||||
MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW,
|
||||
@@ -643,7 +699,7 @@ class TreeManager:
|
||||
qry = (
|
||||
self.db.query(Message)
|
||||
.select_from(MessageTreeState)
|
||||
.outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
|
||||
.filter(
|
||||
MessageTreeState.active,
|
||||
MessageTreeState.state == message_tree_state.State.GROWING,
|
||||
@@ -664,7 +720,7 @@ class TreeManager:
|
||||
SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count,
|
||||
COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings
|
||||
FROM message_tree_state mts
|
||||
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
|
||||
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
|
||||
WHERE mts.active -- only consider active trees
|
||||
AND mts.state = :ranking_state -- message tree must be in ranking state
|
||||
AND m.review_result -- must be reviewed
|
||||
@@ -690,15 +746,15 @@ HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
|
||||
-- find all extendible parent nodes
|
||||
SELECT m.id as parent_id, m.role as parent_role, m.depth, m.message_tree_id, COUNT(c.id) active_children_count
|
||||
FROM message_tree_state mts
|
||||
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
|
||||
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
|
||||
LEFT JOIN message c ON m.id = c.parent_id -- child nodes
|
||||
WHERE mts.active -- only consider active trees
|
||||
AND mts.state = :growing_state -- message tree must be growing
|
||||
AND NOT m.deleted -- ignore deleted messages as parents
|
||||
AND m.depth < mts.max_depth -- ignore leaf nodes as parents
|
||||
AND m.review_result -- parent node must have positive review
|
||||
AND NOT c.deleted -- don't count deleted children
|
||||
AND (c.review_result OR c.review_count < :num_reviews_reply) -- don't count children with negative review but count elements under review
|
||||
AND NOT coalesce(c.deleted, FALSE) -- don't count deleted children
|
||||
AND (c.review_result OR coalesce(c.review_count, 0) < :num_reviews_reply) -- don't count children with negative review but count elements under review
|
||||
GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count
|
||||
HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
|
||||
"""
|
||||
@@ -708,7 +764,10 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
|
||||
|
||||
r = self.db.execute(
|
||||
text(self._sql_find_extendible_parents),
|
||||
{"growing_state": message_tree_state.State.GROWING, "num_reviews_reply": self.cfg.num_reviews_reply},
|
||||
{
|
||||
"growing_state": message_tree_state.State.GROWING,
|
||||
"num_reviews_reply": self.cfg.num_reviews_reply,
|
||||
},
|
||||
)
|
||||
return [ExtendibleParentRow.from_orm(x) for x in r.all()]
|
||||
|
||||
@@ -717,8 +776,8 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
|
||||
SELECT m.message_tree_id, mts.goal_tree_size, COUNT(m.id) AS tree_size
|
||||
FROM (
|
||||
SELECT DISTINCT message_tree_id FROM ({_sql_find_extendible_parents}) extendible_parents
|
||||
) trees LEFT JOIN message_tree_state mts ON trees.message_tree_id = mts.message_tree_id
|
||||
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
|
||||
) trees INNER JOIN message_tree_state mts ON trees.message_tree_id = mts.message_tree_id
|
||||
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
|
||||
WHERE NOT m.deleted
|
||||
AND (
|
||||
m.parent_id IS NOT NULL AND (m.review_result OR m.review_count < :num_reviews_reply) -- children
|
||||
@@ -766,7 +825,7 @@ HAVING COUNT(m.id) < mts.goal_tree_size
|
||||
"""Find all initial prompt messages that have no associated message tree state"""
|
||||
qry_missing_tree_states = (
|
||||
self.db.query(Message.id)
|
||||
.join(MessageTreeState, isouter=True)
|
||||
.outerjoin(MessageTreeState, Message.message_tree_id == MessageTreeState.message_tree_id)
|
||||
.filter(
|
||||
Message.parent_id.is_(None),
|
||||
Message.message_tree_id == Message.id,
|
||||
@@ -783,7 +842,7 @@ SELECT p.parent_id, mr.* FROM
|
||||
-- find parents with > 1 children
|
||||
SELECT m.parent_id, m.message_tree_id, COUNT(m.id) children_count
|
||||
FROM message_tree_state mts
|
||||
LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
|
||||
INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
|
||||
WHERE m.review_result -- must be reviewed
|
||||
AND NOT m.deleted -- not deleted
|
||||
AND m.parent_id IS NOT NULL -- ignore initial prompts
|
||||
@@ -792,8 +851,8 @@ SELECT p.parent_id, mr.* FROM
|
||||
GROUP BY m.parent_id, m.message_tree_id
|
||||
HAVING COUNT(m.id) > 1
|
||||
) as p
|
||||
LEFT JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload')
|
||||
LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
|
||||
INNER JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload')
|
||||
INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
|
||||
"""
|
||||
|
||||
def query_tree_ranking_results(
|
||||
@@ -832,7 +891,7 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
|
||||
state = message_tree_state.State.INITIAL_PROMPT_REVIEW
|
||||
if tree_size > 1:
|
||||
state = message_tree_state.State.GROWING
|
||||
logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=})")
|
||||
logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=:s})")
|
||||
self._insert_default_state(id, state=state)
|
||||
|
||||
def query_num_active_trees(self) -> int:
|
||||
@@ -885,6 +944,194 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
|
||||
active=True,
|
||||
)
|
||||
|
||||
def tree_counts_by_state(self) -> dict[str, int]:
|
||||
qry = self.db.query(
|
||||
MessageTreeState.state, func.count(MessageTreeState.message_tree_id).label("count")
|
||||
).group_by(MessageTreeState.state)
|
||||
return {x["state"]: x["count"] for x in qry}
|
||||
|
||||
def tree_message_count_stats(self, only_active: bool = True) -> list[TreeMessageCountStats]:
|
||||
qry = (
|
||||
self.db.query(
|
||||
MessageTreeState.message_tree_id,
|
||||
func.max(Message.depth).label("depth"),
|
||||
func.min(Message.created_date).label("oldest"),
|
||||
func.max(Message.created_date).label("youngest"),
|
||||
func.count(Message.id).label("count"),
|
||||
MessageTreeState.goal_tree_size,
|
||||
MessageTreeState.state,
|
||||
)
|
||||
.select_from(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
|
||||
.filter(not_(Message.deleted))
|
||||
.group_by(MessageTreeState.message_tree_id)
|
||||
)
|
||||
|
||||
if only_active:
|
||||
qry = qry.filter(MessageTreeState.active)
|
||||
|
||||
return [TreeMessageCountStats(**x) for x in qry]
|
||||
|
||||
def stats(self) -> TreeManagerStats:
|
||||
return TreeManagerStats(
|
||||
state_counts=self.tree_counts_by_state(),
|
||||
message_counts=self.tree_message_count_stats(only_active=True),
|
||||
)
|
||||
|
||||
def get_user_messages_by_tree(
|
||||
self,
|
||||
user_id: UUID,
|
||||
min_date: datetime = None,
|
||||
max_date: datetime = None,
|
||||
) -> Tuple[dict[UUID, list[Message]], list[Message]]:
|
||||
"""Returns a dict with replies by tree (excluding initial prompts) and list of initial prompts
|
||||
associated with user_id."""
|
||||
|
||||
# query all messages of the user
|
||||
qry = self.db.query(Message).filter(Message.user_id == user_id)
|
||||
if min_date:
|
||||
qry = qry.filter(Message.created_date >= min_date)
|
||||
if max_date:
|
||||
qry = qry.filter(Message.created_date <= max_date)
|
||||
|
||||
prompts: list[Message] = []
|
||||
replies_by_tree: dict[UUID, list[Message]] = {}
|
||||
|
||||
# walk over result set and distinguish between initial prompts and replies
|
||||
for m in qry:
|
||||
m: Message
|
||||
|
||||
if m.message_tree_id == m.id:
|
||||
prompts.append(m)
|
||||
else:
|
||||
message_list = replies_by_tree.get(m.message_tree_id)
|
||||
if message_list is None:
|
||||
message_list = [m]
|
||||
replies_by_tree[m.message_tree_id] = message_list
|
||||
else:
|
||||
message_list.append(m)
|
||||
|
||||
return replies_by_tree, prompts
|
||||
|
||||
def _purge_message_internal(self, message_id: UUID) -> None:
|
||||
"""This internal function deletes a single message. It does not take care of
|
||||
descendants, children_count in parent etc."""
|
||||
|
||||
sql_purge_message = """
|
||||
DELETE FROM journal j USING message m WHERE j.message_id = :message_id;
|
||||
DELETE FROM message_embedding e WHERE e.message_id = :message_id;
|
||||
DELETE FROM message_toxicity t WHERE t.message_id = :message_id;
|
||||
DELETE FROM text_labels l WHERE l.message_id = :message_id;
|
||||
-- delete all ranking results that contain message
|
||||
DELETE FROM message_reaction r WHERE r.payload_type = 'RankingReactionPayload' AND r.task_id IN (
|
||||
SELECT t.id FROM message m
|
||||
JOIN task t ON m.parent_id = t.parent_message_id
|
||||
WHERE m.id = :message_id);
|
||||
-- delete task which inserted message
|
||||
DELETE FROM task t using message m WHERE t.id = m.task_id AND m.id = :message_id;
|
||||
DELETE FROM task t WHERE t.parent_message_id = :message_id;
|
||||
DELETE FROM message WHERE id = :message_id;
|
||||
"""
|
||||
r = self.db.execute(text(sql_purge_message), {"message_id": message_id})
|
||||
logger.debug(f"purge_message({message_id=}): {r.rowcount} rows.")
|
||||
|
||||
def purge_message_tree(self, message_tree_id: UUID) -> None:
|
||||
sql_purge_message_tree = """
|
||||
DELETE FROM journal j USING message m WHERE j.message_id = m.Id AND m.message_tree_id = :message_tree_id;
|
||||
DELETE FROM message_embedding e USING message m WHERE e.message_id = m.Id AND m.message_tree_id = :message_tree_id;
|
||||
DELETE FROM message_toxicity t USING message m WHERE t.message_id = m.Id AND m.message_tree_id = :message_tree_id;
|
||||
DELETE FROM text_labels l USING message m WHERE l.message_id = m.Id AND m.message_tree_id = :message_tree_id;
|
||||
DELETE FROM message_reaction r USING task t WHERE r.task_id = t.id AND t.message_tree_id = :message_tree_id;
|
||||
DELETE FROM task t WHERE t.message_tree_id = :message_tree_id;
|
||||
DELETE FROM message_tree_state WHERE message_tree_id = :message_tree_id;
|
||||
DELETE FROM message WHERE message_tree_id = :message_tree_id;
|
||||
"""
|
||||
r = self.db.execute(text(sql_purge_message_tree), {"message_tree_id": message_tree_id})
|
||||
logger.debug(f"purge_message_tree({message_tree_id=}) {r.rowcount} rows.")
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def purge_user_messages(
|
||||
self,
|
||||
user_id: UUID,
|
||||
purge_initial_prompts: bool = True,
|
||||
min_date: datetime = None,
|
||||
max_date: datetime = None,
|
||||
):
|
||||
|
||||
# find all affected message trees
|
||||
replies_by_tree, prompts = self.get_user_messages_by_tree(user_id, min_date, max_date)
|
||||
total_messages = sum(len(x) for x in replies_by_tree.values())
|
||||
logger.debug(f"found: {len(replies_by_tree)} trees; {len(prompts)} prompts; {total_messages} messages;")
|
||||
|
||||
# remove all trees based on inital prompts of the user
|
||||
if purge_initial_prompts:
|
||||
for p in prompts:
|
||||
self.purge_message_tree(p.message_tree_id)
|
||||
if p.message_tree_id in replies_by_tree:
|
||||
del replies_by_tree[p.message_tree_id]
|
||||
|
||||
# patch all affected message trees
|
||||
for tree_id, replies in replies_by_tree.items():
|
||||
bad_parent_ids = set(m.id for m in replies)
|
||||
logger.debug(f"patching tree {tree_id=}, {bad_parent_ids=}")
|
||||
|
||||
tree_messages = self.pr.fetch_message_tree(tree_id, reviewed=False, include_deleted=True)
|
||||
logger.debug(f"{tree_id=}, {len(bad_parent_ids)=}, {len(tree_messages)=}")
|
||||
by_id = {m.id: m for m in tree_messages}
|
||||
|
||||
def ancestor_ids(msg: Message) -> list[UUID]:
|
||||
t = []
|
||||
while msg.parent_id is not None:
|
||||
msg = by_id[msg.parent_id]
|
||||
t.append(msg.id)
|
||||
return t
|
||||
|
||||
def is_descendant_of_deleted(m: Message) -> bool:
|
||||
if m.id in bad_parent_ids:
|
||||
return True
|
||||
ancestors = ancestor_ids(m)
|
||||
if any(a in bad_parent_ids for a in ancestors):
|
||||
return True
|
||||
return False
|
||||
|
||||
# start with deepest messages first
|
||||
tree_messages.sort(key=lambda x: x.depth, reverse=True)
|
||||
for m in tree_messages:
|
||||
if is_descendant_of_deleted(m):
|
||||
logger.debug(f"purging message: {m.id}")
|
||||
self._purge_message_internal(m.id)
|
||||
|
||||
# update childern counts
|
||||
self.pr.update_children_counts(m.message_tree_id)
|
||||
|
||||
# reactivate tree
|
||||
logger.info(f"reactivating message tree {tree_id}")
|
||||
mts = self.pr.fetch_tree_state(tree_id)
|
||||
mts.active = True
|
||||
self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW)
|
||||
self.check_condition_for_growing_state(tree_id)
|
||||
self.check_condition_for_ranking_state(tree_id)
|
||||
self.check_condition_for_scoring_state(tree_id)
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def purge_user(self, user_id: UUID, ban: bool = True) -> None:
|
||||
self.purge_user_messages(user_id, purge_initial_prompts=True)
|
||||
|
||||
# delete all remaining rows and ban user
|
||||
sql_purge_user = """
|
||||
DELETE FROM journal WHERE user_id = :user_id;
|
||||
DELETE FROM message_reaction WHERE user_id = :user_id;
|
||||
DELETE FROM task WHERE user_id = :user_id;
|
||||
DELETE FROM message WHERE user_id = :user_id;
|
||||
DELETE FROM user_stats WHERE user_id = :user_id;
|
||||
"""
|
||||
|
||||
r = self.db.execute(text(sql_purge_user), {"user_id": user_id})
|
||||
logger.debug(f"purge_user({user_id=}): {r.rowcount} rows.")
|
||||
|
||||
if ban:
|
||||
self.db.execute(update(User).filter(User.id == user_id).values(deleted=True, enabled=False))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from oasst_backend.api.deps import api_auth
|
||||
@@ -901,6 +1148,10 @@ if __name__ == "__main__":
|
||||
tm = TreeManager(db, pr, cfg)
|
||||
tm.ensure_tree_states()
|
||||
|
||||
tm.purge_user_messages(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"), purge_initial_prompts=False)
|
||||
# tm.purge_user(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"))
|
||||
# db.commit()
|
||||
|
||||
# print("query_num_active_trees", tm.query_num_active_trees())
|
||||
# print("query_incomplete_rankings", tm.query_incomplete_rankings())
|
||||
# print("query_replies_need_review", tm.query_replies_need_review())
|
||||
@@ -909,10 +1160,10 @@ if __name__ == "__main__":
|
||||
# print("query_extendible_parents", tm.query_extendible_parents())
|
||||
# print("query_tree_size", tm.query_tree_size(message_tree_id=UUID("bdf434cf-4df5-4b74-949c-a5a157bc3292")))
|
||||
|
||||
print(
|
||||
"query_reviews_for_message",
|
||||
tm.query_reviews_for_message(message_id=UUID("6a444493-0d48-4316-a9f1-7e263f5a2473")),
|
||||
)
|
||||
# print(
|
||||
# "query_reviews_for_message",
|
||||
# tm.query_reviews_for_message(message_id=UUID("6a444493-0d48-4316-a9f1-7e263f5a2473")),
|
||||
# )
|
||||
|
||||
# print("next_task:", tm.next_task())
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from loguru import logger
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.models import Message, MessageReaction, Task, User, UserStats, UserStatsTimeFrame
|
||||
from oasst_backend.models.db_payload import (
|
||||
LabelAssistantReplyPayload,
|
||||
@@ -39,12 +40,16 @@ class UserStatsRepository:
|
||||
self.session.query(User.id.label("user_id"), User.username, User.auth_method, User.display_name, UserStats)
|
||||
.join(UserStats, User.id == UserStats.user_id)
|
||||
.filter(UserStats.time_frame == time_frame.value)
|
||||
.order_by(UserStats.leader_score.desc())
|
||||
.order_by(UserStats.rank)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
leaderboard = [_create_user_score(r) for r in self.session.exec(qry)]
|
||||
return LeaderboardStats(time_frame=time_frame.value, leaderboard=leaderboard)
|
||||
if len(leaderboard) > 0:
|
||||
last_update = max(x.modified_date for x in leaderboard)
|
||||
else:
|
||||
last_update = utcnow()
|
||||
return LeaderboardStats(time_frame=time_frame.value, leaderboard=leaderboard, last_updated=last_update)
|
||||
|
||||
def get_user_stats_all_time_frames(self, user_id: UUID) -> dict[str, UserScore | None]:
|
||||
qry = (
|
||||
@@ -291,13 +296,11 @@ WHERE
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from oasst_backend.api.deps import get_dummy_api_client
|
||||
from oasst_backend.api.deps import api_auth
|
||||
from oasst_backend.database import engine
|
||||
|
||||
with Session(engine) as session:
|
||||
api_client = get_dummy_api_client(session)
|
||||
usr = UserStatsRepository(session)
|
||||
# usr.update_all_time_frames()
|
||||
# session.commit()
|
||||
# usr.get_leader_board(UserStatsTimeFrame.total)
|
||||
usr.get_user_stats_all_time_frames(UUID("0d6ff62a-0bea-4c56-ade8-b3e0520a10ce"))
|
||||
with Session(engine) as db:
|
||||
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
|
||||
usr = UserStatsRepository(db)
|
||||
usr.update_all_time_frames()
|
||||
db.commit()
|
||||
|
||||
@@ -19,6 +19,7 @@ class CommitMode(IntEnum):
|
||||
NONE = 0
|
||||
FLUSH = 1
|
||||
COMMIT = 2
|
||||
ROLLBACK = 3
|
||||
|
||||
|
||||
"""
|
||||
@@ -41,6 +42,8 @@ def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=s
|
||||
self.db.commit()
|
||||
elif auto_commit == CommitMode.FLUSH:
|
||||
self.db.flush()
|
||||
elif auto_commit == CommitMode.ROLLBACK:
|
||||
self.db.rollback()
|
||||
if isinstance(result, SQLModel):
|
||||
self.db.refresh(result)
|
||||
return result
|
||||
@@ -75,6 +78,8 @@ def async_managed_tx_method(
|
||||
self.db.commit()
|
||||
elif auto_commit == CommitMode.FLUSH:
|
||||
self.db.flush()
|
||||
elif auto_commit == CommitMode.ROLLBACK:
|
||||
self.db.rollback()
|
||||
if isinstance(result, SQLModel):
|
||||
self.db.refresh(result)
|
||||
return result
|
||||
@@ -118,8 +123,8 @@ def managed_tx_function(
|
||||
session.commit()
|
||||
elif auto_commit == CommitMode.FLUSH:
|
||||
session.flush()
|
||||
if isinstance(result, SQLModel):
|
||||
session.refresh(result)
|
||||
elif auto_commit == CommitMode.ROLLBACK:
|
||||
session.rollback()
|
||||
return result
|
||||
except OperationalError:
|
||||
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def head_to_head_votes(ranks: List[List[int]]):
|
||||
tallies = np.zeros((len(ranks[0]), len(ranks[0])))
|
||||
names = sorted(ranks[0])
|
||||
ranks = np.array(ranks)
|
||||
# we want the sorted indices
|
||||
ranks = np.argsort(ranks, axis=1)
|
||||
for i in range(ranks.shape[1]):
|
||||
for j in range(i + 1, ranks.shape[1]):
|
||||
# now count the cases someone voted for i over j
|
||||
over_j = np.sum(ranks[:, i] < ranks[:, j])
|
||||
over_i = np.sum(ranks[:, j] < ranks[:, i])
|
||||
tallies[i, j] = over_j
|
||||
# tallies[i,j] = over_i
|
||||
tallies[j, i] = over_i
|
||||
# tallies[j,i] = over_j
|
||||
return tallies, names
|
||||
|
||||
|
||||
def cycle_detect(pairs):
|
||||
"""Recursively detect cylces by removing condorcet losers until either only one pair is left or condorcet loosers no longer exist
|
||||
This method upholds the invariant that in a ranking for all a,b either a>b or b>a for all a,b.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
out : False if the pairs do not contain a cycle, True if the pairs contain a cycle
|
||||
|
||||
|
||||
"""
|
||||
# get all condorcet losers (pairs that loose to all other pairs)
|
||||
# idea: filter all losers that are never winners
|
||||
# print("pairs", pairs)
|
||||
if len(pairs) <= 1:
|
||||
return False
|
||||
losers = [c_lose for c_lose in np.unique(pairs[:, 1]) if c_lose not in pairs[:, 0]]
|
||||
if len(losers) == 0:
|
||||
# if we recursively removed pairs, and at some point we did not have
|
||||
# a condorcet loser, that means everything is both a winner and loser,
|
||||
# yielding at least one (winner,loser), (loser,winner) pair
|
||||
return True
|
||||
|
||||
new = []
|
||||
for p in pairs:
|
||||
if p[1] not in losers:
|
||||
new.append(p)
|
||||
return cycle_detect(np.array(new))
|
||||
|
||||
|
||||
def get_winner(pairs):
|
||||
"""
|
||||
This returns _one_ concordant winner.
|
||||
It could be that there are multiple concordant winners, but in our case
|
||||
since we are interested in a ranking, we have to choose one at random.
|
||||
"""
|
||||
losers = np.unique(pairs[:, 1]).astype(int)
|
||||
winners = np.unique(pairs[:, 0]).astype(int)
|
||||
for w in winners:
|
||||
if w not in losers:
|
||||
return w
|
||||
|
||||
|
||||
def get_ranking(pairs):
|
||||
"""
|
||||
Abuses concordance property to get a (not necessarily unqiue) ranking.
|
||||
The lack of uniqueness is due to the potential existence of multiple
|
||||
equally ranked winners. We have to pick one, which is where
|
||||
the non-uniqueness comes from
|
||||
"""
|
||||
if len(pairs) == 1:
|
||||
return list(pairs[0])
|
||||
w = get_winner(pairs)
|
||||
# now remove the winner from the list of pairs
|
||||
p_new = np.array([(a, b) for a, b in pairs if a != w])
|
||||
return [w] + get_ranking(p_new)
|
||||
|
||||
|
||||
def ranked_pairs(ranks: List[List[int]]):
|
||||
"""
|
||||
Expects a list of rankings for an item like:
|
||||
[("w","x","z","y") for _ in range(3)]
|
||||
+ [("w","y","x","z") for _ in range(2)]
|
||||
+ [("x","y","z","w") for _ in range(4)]
|
||||
+ [("x","z","w","y") for _ in range(5)]
|
||||
+ [("y","w","x","z") for _ in range(1)]
|
||||
This code is quite brain melting, but the idea is the following:
|
||||
1. create a head-to-head matrix that tallies up all win-lose combinations of preferences
|
||||
2. take all combinations that win more than they loose and sort those by how often they win
|
||||
3. use that to create an (implicit) directed graph
|
||||
4. recursively extract nodes from the graph that do not have incoming edges
|
||||
5. said recursive list is the ranking
|
||||
"""
|
||||
tallies, names = head_to_head_votes(ranks)
|
||||
tallies = tallies - tallies.T
|
||||
# print(tallies)
|
||||
# note: the resulting tally matrix should be skew-symmetric
|
||||
# order by strength of victory (using tideman's original method, don't think it would make a difference for us)
|
||||
sorted_majorities = []
|
||||
for i in range(len(ranks[0])):
|
||||
for j in range(len(ranks[0])):
|
||||
if tallies[i, j] > 0:
|
||||
sorted_majorities.append((i, j, tallies[i, j]))
|
||||
# we don't explicitly deal with tied majorities here
|
||||
sorted_majorities = np.array(sorted(sorted_majorities, key=lambda x: x[2], reverse=True))
|
||||
# now do lock ins
|
||||
lock_ins = []
|
||||
for (x, y, _) in sorted_majorities:
|
||||
# invariant: lock_ins has no cycles here
|
||||
lock_ins.append((x, y))
|
||||
# print("lock ins are now",np.array(lock_ins))
|
||||
if cycle_detect(np.array(lock_ins)):
|
||||
# print("backup: cycle detected")
|
||||
# if there's a cycle, delete the new addition and continue
|
||||
lock_ins = lock_ins[:-1]
|
||||
# now simply return all winners in order, and attach the losers
|
||||
# to the back. This is because the overall loser might not be unique
|
||||
# and (by concordance property) may never exist in any winning set to begin with.
|
||||
# (otherwise he would either not be the loser, or cycles exist!)
|
||||
# Since there could be multiple overall losers, we just return them in any order
|
||||
# as we are unable to find a closer ranking
|
||||
numerical_ranks = np.array(get_ranking(np.array(lock_ins))).astype(int)
|
||||
conversion = [names[n] for n in numerical_ranks]
|
||||
return conversion
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ranks = (
|
||||
[("w", "x", "z", "y") for _ in range(1)]
|
||||
+ [("w", "y", "x", "z") for _ in range(2)]
|
||||
# + [("x","y","z","w") for _ in range(4)]
|
||||
+ [("x", "z", "w", "y") for _ in range(5)]
|
||||
+ [("y", "w", "x", "z") for _ in range(1)]
|
||||
# [("y","z","w","x") for _ in range(1000)]
|
||||
)
|
||||
rp = ranked_pairs(ranks)
|
||||
print(rp)
|
||||
+2
-1
@@ -18,7 +18,8 @@ services:
|
||||
|
||||
# This DB is for the FastAPI Backend.
|
||||
db:
|
||||
image: postgres
|
||||
image: ghcr.io/laion-ai/open-assistant/oasst-postgres
|
||||
pull_policy: always
|
||||
restart: always
|
||||
ports:
|
||||
- 5432:5432
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
FROM postgres:15
|
||||
|
||||
# install unzip
|
||||
RUN apt-get update && apt-get install -y unzip curl && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# download aws cli
|
||||
RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
|
||||
RUN unzip -q awscliv2.zip
|
||||
RUN ./aws/install
|
||||
|
||||
COPY ./backup_pg_to_s3.sh .
|
||||
Executable
+15
@@ -0,0 +1,15 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
set -x
|
||||
|
||||
# filename with timestamp
|
||||
filename="postgres-$(date +%Y-%m-%d_%H-%M-%S).sql"
|
||||
|
||||
# perform pg_dump
|
||||
pg_dump -U postgres postgres > /tmp/$filename
|
||||
|
||||
# upload to s3
|
||||
aws s3 cp /tmp/$filename s3://$S3_BUCKET_NAME/$filename
|
||||
|
||||
rm /tmp/$filename
|
||||
@@ -0,0 +1,3 @@
|
||||
# Frequently Asked Questions
|
||||
|
||||
In this page, there are some of the most frequently asked questions.
|
||||
@@ -0,0 +1,65 @@
|
||||
### Docker-Compose instead of Docker Compose
|
||||
|
||||
If you are using `docker-compose` instead of `docker compose` (note the " "
|
||||
instead of the "-"), you should update your docker cli to the latest version.
|
||||
`docker compose` is the most recent version and should be used instead of
|
||||
`docker-compose`
|
||||
|
||||
For more details and information check out
|
||||
[this SO thread](https://stackoverflow.com/questions/66514436/difference-between-docker-compose-and-docker-compose)
|
||||
that explains it all in detail.
|
||||
|
||||
### Pre-commit
|
||||
|
||||
We are using pre-commit to ensure the quality of the code as well as the same
|
||||
code standard.
|
||||
|
||||
The steps that you need to follow to be able to use it are:
|
||||
|
||||
```bash
|
||||
# install pre-commit in your python environment
|
||||
pip3 install pre-commit
|
||||
|
||||
# install pre-commit in your github configuration
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
So from now on, in your next commits it will run the `pre-commit` on the files
|
||||
that have been staged. If there has been any error, you will need to solve that,
|
||||
and then stage+commit again the changes.
|
||||
|
||||
## Docker Cannot Start Container: Permission Denied
|
||||
|
||||
Instead of running docker with the root command always, you could create a
|
||||
`docker` group with granted permissions (root):
|
||||
|
||||
```bash
|
||||
# Create new linux user
|
||||
sudo groupadd docker
|
||||
|
||||
# Add the actual user to the group
|
||||
sudo usermod -aG docker $USER
|
||||
|
||||
# Log in the group (apply the group changes to actual terminal session)
|
||||
newgrp docker
|
||||
```
|
||||
|
||||
After that, you should be able to run docker: `docker run .`. In the case you
|
||||
still are not able, can try to reboot terminal:
|
||||
|
||||
```bash
|
||||
reboot
|
||||
```
|
||||
|
||||
### Docker Cannot Stop Container
|
||||
|
||||
If you try to shut down the services (`docker-compose down`), and you are
|
||||
getting permission denied (using root user), you can try the following:
|
||||
|
||||
```bash
|
||||
# Restart docker daemon
|
||||
sudo systemctl restart docker.socket docker.service
|
||||
|
||||
# And remove the container
|
||||
docker rm -f <container id>
|
||||
```
|
||||
@@ -70,6 +70,15 @@ const sidebars = {
|
||||
},
|
||||
items: ["presentations/list"],
|
||||
},
|
||||
{
|
||||
type: "category",
|
||||
label: "FAQ",
|
||||
link: {
|
||||
type: "doc",
|
||||
id: "faq/README",
|
||||
},
|
||||
items: ["faq/faq"],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ defaults:
|
||||
- joke
|
||||
- gsm8k
|
||||
- samsum
|
||||
- soda_dialogue
|
||||
cache_dir: .cache
|
||||
loss_fn: CrossEntropyLoss
|
||||
eval_size:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from custom_datasets.prompt_dialogue import PromptGeneratedDataset
|
||||
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, WebGPT
|
||||
from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, WebGPT
|
||||
from custom_datasets.summarization import SummarizationDataset
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Subset
|
||||
@@ -36,6 +36,9 @@ def get_one_dataset(conf, dataset_name):
|
||||
elif dataset_name == "soda":
|
||||
dataset = SODA(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.1)
|
||||
elif dataset_name == "soda_dialogue":
|
||||
dataset = SODADialogue(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.1)
|
||||
elif dataset_name == "joke":
|
||||
dataset = JokeExplaination(conf.cache_dir)
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
|
||||
@@ -106,7 +106,12 @@ class SODA(Dataset):
|
||||
def process_soda_convo(self, data):
|
||||
pairs = []
|
||||
play_as = data["speakers"][1]
|
||||
prefix = "<prefix>{}. {}</prefix>".format(data["narrative"], "your name {}".format(play_as))
|
||||
prefix = "{}{}. {}{}".format(
|
||||
QA_SPECIAL_TOKENS["StartPrefix"],
|
||||
data["narrative"],
|
||||
"your name {}".format(play_as),
|
||||
QA_SPECIAL_TOKENS["EndPrefix"],
|
||||
)
|
||||
question, answer = "", ""
|
||||
prefix, postfix = "", ""
|
||||
previous_chat = []
|
||||
@@ -119,7 +124,9 @@ class SODA(Dataset):
|
||||
answer = convo
|
||||
postfix = data["speakers"][idx]
|
||||
if len(question) and len(answer) and prefix != postfix and postfix == play_as:
|
||||
history = "<sep>".join(["{}<bot>{}".format(*p) for p in previous_chat])
|
||||
history = "<sep>".join(
|
||||
["{}{}{}".format(p[0], QA_SPECIAL_TOKENS["Answer"], p[1]) for p in previous_chat]
|
||||
)
|
||||
if len(history):
|
||||
history += "<sep>"
|
||||
pairs.append((prefix + history + question, answer))
|
||||
@@ -148,6 +155,57 @@ class SODA(Dataset):
|
||||
return question, answer
|
||||
|
||||
|
||||
class SODADialogue(Dataset):
|
||||
url = "https://drive.google.com/uc?id=1TOGQfr419n8wpzJpYLLw4nB3tSKD8zXV"
|
||||
|
||||
def __init__(self, cache_dir, verbose=True):
|
||||
|
||||
path = os.path.join(cache_dir, "soda_dialog.jsonl")
|
||||
|
||||
if not os.path.exists(path):
|
||||
import gzip
|
||||
import shutil
|
||||
|
||||
import gdown
|
||||
|
||||
gdown.download(self.url, output=os.path.join(cache_dir, "soda_dialog.jsonl.gz"))
|
||||
|
||||
with gzip.open(os.path.join(cache_dir, "soda_dialog.jsonl.gz"), "rb") as f_in:
|
||||
with open(path, "wb") as f_out:
|
||||
shutil.copyfileobj(f_in, f_out)
|
||||
|
||||
self.pairs = []
|
||||
faulty = 0
|
||||
with open(path) as fin:
|
||||
for line in fin:
|
||||
conversation = json.loads(line)
|
||||
question_answer_pairs = ()
|
||||
|
||||
question_answers = conversation["text"].split("User: ")
|
||||
for question_answer in question_answers[1:]: # first element is empty
|
||||
try:
|
||||
question, answer = question_answer.split("\nAssistant: ")
|
||||
question_answer_pairs += (
|
||||
question,
|
||||
answer,
|
||||
)
|
||||
except ValueError:
|
||||
# there might be some extra 'User: ' or 'Assistant: ' tokens in the dataset that cause trouble..
|
||||
faulty += 1
|
||||
continue
|
||||
|
||||
self.pairs.append(question_answer_pairs)
|
||||
|
||||
if verbose:
|
||||
print("For SODA dialogue dataset found {} faults within the total {} dialogs".format(faulty, len(self)))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.pairs[index]
|
||||
|
||||
|
||||
class JokeExplaination(Dataset):
|
||||
""" """
|
||||
|
||||
|
||||
@@ -3,10 +3,12 @@ bitsandbytes==0.36.0.post2
|
||||
datasets==2.8.0
|
||||
deepspeed==0.7.7
|
||||
evaluate==0.4.0
|
||||
gdown
|
||||
mpi4py==3.1.4
|
||||
nltk==3.8.1
|
||||
numpy==1.23.0
|
||||
PyYAML==6.0
|
||||
numpy>=1.22.4
|
||||
py7zr
|
||||
PyYAML>=6.0
|
||||
scikit_learn==1.2.0
|
||||
torch==1.13.1
|
||||
torch>=1.11.0
|
||||
transformers==4.25.1
|
||||
|
||||
@@ -0,0 +1,198 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## CodeT Code Generation Datasets\n",
|
||||
"\n",
|
||||
"[](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/data-augmentation/codet-data/Augment_CodeT_codegen.ipynb)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This notebook contains code to parse CodeT code generation prompt and solution data and modify to `(prompt, solution)` pairs outputted in a `.jsonl` file.\n",
|
||||
"\n",
|
||||
"Requirements: `requests`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"from pathlib import Path\n",
|
||||
"import requests\n",
|
||||
"from typing import Dict, List, Tuple"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DATA_FILES: List[str] = [\n",
|
||||
" \"HumanEval_for_code_generation.jsonl\",\n",
|
||||
" \"mbpp_sanitized_for_code_generation.jsonl\",\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"OUT_FILES: List[str] = [\n",
|
||||
" \"HumanEval_codegen.jsonl\",\n",
|
||||
" \"mbpp_codegen.jsonl\",\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"FILE_PATHS: List[Path] = [Path(f\"data/{data_file}\") for data_file in DATA_FILES]\n",
|
||||
"\n",
|
||||
"OUT_PATHS: List[Path] = [Path(f\"data/augmented/{out_file}\") for out_file in OUT_FILES]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def download_file(filename: str):\n",
|
||||
" url = f\"https://raw.githubusercontent.com/microsoft/CodeT/main/CodeT/data/dataset/{filename}\"\n",
|
||||
" response = requests.get(url)\n",
|
||||
" with open(f\"data/{filename}\", \"wb\") as f:\n",
|
||||
" f.write(response.content)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for filename in DATA_FILES:\n",
|
||||
" download_file(filename)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can find the docstring, use its contents as the instruction (prefixed with \"Write a function corresponding to the docstring:\") and then use the content prior to the docstring and the canonical solution as the response."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_docstring_indices(prompt_lines: List[str]) -> Tuple[int, int]:\n",
|
||||
" docstring_start, docstring_end = None, None\n",
|
||||
"\n",
|
||||
" for i, line in enumerate(prompt_lines):\n",
|
||||
" if not (line.strip().startswith('\"\"\"') or line.strip().startswith(\"'''\")):\n",
|
||||
" continue\n",
|
||||
" if docstring_start:\n",
|
||||
" docstring_end = i\n",
|
||||
" break\n",
|
||||
" docstring_start = i\n",
|
||||
"\n",
|
||||
" if docstring_end:\n",
|
||||
" return docstring_start, docstring_end\n",
|
||||
" raise ValueError(f\"No complete docstring found!\\n{prompt_lines}\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_before(prompt_lines: List[str], before: int) -> List[str]:\n",
|
||||
" before_lines = prompt_lines[:before]\n",
|
||||
" return before_lines\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_between(prompt_lines: List[str], start: int, end: int) -> List[str]:\n",
|
||||
" between_lines = prompt_lines[start:end]\n",
|
||||
" return between_lines"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_request_and_solution(sample: dict) -> Tuple[List[str], List[str]]:\n",
|
||||
" prompt = sample[\"prompt\"]\n",
|
||||
" prompt_lines = prompt.splitlines()\n",
|
||||
"\n",
|
||||
" docstring_start, docstring_end = get_docstring_indices(prompt_lines)\n",
|
||||
"\n",
|
||||
" # Extract prompt\n",
|
||||
" in_docstring = get_between(prompt_lines, docstring_start, docstring_end)\n",
|
||||
" if '\"\"\"' in in_docstring[0] or \"'''\" in in_docstring[0]:\n",
|
||||
" in_docstring[0] = in_docstring[0].replace('\"\"\"', \"\").replace(\"...\", \"\").strip()\n",
|
||||
" request = \"Write a Python function corresponding to the docstring: \" + \" \".join([p.strip() for p in in_docstring])\n",
|
||||
"\n",
|
||||
" # Extract solution\n",
|
||||
" before_docstring = get_before(prompt_lines, docstring_start)\n",
|
||||
" after_docstring = sample[\"canonical_solution\"].splitlines()\n",
|
||||
" solution = before_docstring + after_docstring\n",
|
||||
" # Gets rid of consecutive empty lines\n",
|
||||
" solution = [v for i, v in enumerate(solution) if v != \"\" or v != solution[i - 1]]\n",
|
||||
" solution = \"\\n\".join(solution)\n",
|
||||
"\n",
|
||||
" return request, solution"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def process_file(file_path: Path, out_path: Path):\n",
|
||||
" lines = file_path.read_text().splitlines()\n",
|
||||
" samples = list(map(json.loads, lines))\n",
|
||||
"\n",
|
||||
" output = []\n",
|
||||
" for sample in samples:\n",
|
||||
" prompt, solution = get_request_and_solution(sample)\n",
|
||||
" output.append({\"prompt\": prompt, \"solution\": solution})\n",
|
||||
"\n",
|
||||
" with open(out_path, \"w\") as f:\n",
|
||||
" for sample in output:\n",
|
||||
" f.write(json.dumps(sample))\n",
|
||||
" f.write(\"\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for file_path, out_path in zip(FILE_PATHS, OUT_PATHS):\n",
|
||||
" process_file(file_path, out_path)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.10.5 ('venv': venv)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.5"
|
||||
},
|
||||
"orig_nbformat": 4,
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "1f9a0efd3e4a33b8f30a65df6ca5a95cc3f93ce2f11519ee8c13fe711de61465"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -0,0 +1,191 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## CodeT Test Generation Datasets\n",
|
||||
"\n",
|
||||
"[](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/data-augmentation/codet-data/Augment_CodeT_testgen.ipynb)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This notebook contains code to parse CodeT test case generation prompt and solution data and modify to `(prompt, solution)` pairs outputted in a `.jsonl` file.\n",
|
||||
"\n",
|
||||
"Requirements: `requests`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"from pathlib import Path\n",
|
||||
"import requests\n",
|
||||
"from typing import List, Tuple"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DATA_FILES: List[str] = [\n",
|
||||
" \"HumanEval_for_test_case_generation.jsonl\",\n",
|
||||
" \"mbpp_sanitized_for_test_case_generation.jsonl\",\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"OUT_FILES: List[str] = [\n",
|
||||
" \"HumanEval_testgen.jsonl\",\n",
|
||||
" \"mbpp_testgen.jsonl\",\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"FILE_PATHS: List[Path] = [Path(f\"data/{data_file}\") for data_file in DATA_FILES]\n",
|
||||
"\n",
|
||||
"OUT_PATHS: List[Path] = [Path(f\"data/augmented/{out_file}\") for out_file in OUT_FILES]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def download_file(filename: str):\n",
|
||||
" url = f\"https://raw.githubusercontent.com/microsoft/CodeT/main/CodeT/data/dataset/{filename}\"\n",
|
||||
" response = requests.get(url)\n",
|
||||
" with open(f\"data/{filename}\", \"wb\") as f:\n",
|
||||
" f.write(response.content)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for filename in DATA_FILES:\n",
|
||||
" download_file(filename)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_docstring_indices(prompt_lines: List[str]) -> Tuple[int, int]:\n",
|
||||
" docstring_start, docstring_end = None, None\n",
|
||||
"\n",
|
||||
" for i, line in enumerate(prompt_lines):\n",
|
||||
" if not (line.strip().startswith('\"\"\"') or line.strip().startswith(\"'''\")):\n",
|
||||
" continue\n",
|
||||
" if docstring_start:\n",
|
||||
" docstring_end = i\n",
|
||||
" break\n",
|
||||
" docstring_start = i\n",
|
||||
"\n",
|
||||
" if docstring_end:\n",
|
||||
" return docstring_start, docstring_end\n",
|
||||
" raise ValueError(f\"No complete docstring found!\\n{prompt_lines}\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_between(prompt_lines: List[str], start: int, end: int) -> List[str]:\n",
|
||||
" between_lines = prompt_lines[start:end]\n",
|
||||
" return between_lines"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_request(sample: dict) -> List[str]:\n",
|
||||
" prompt = sample[\"prompt\"]\n",
|
||||
" prompt_lines = prompt.splitlines()\n",
|
||||
"\n",
|
||||
" docstring_start, docstring_end = get_docstring_indices(prompt_lines)\n",
|
||||
"\n",
|
||||
" # Extract prompt\n",
|
||||
" in_docstring = get_between(prompt_lines, docstring_start, docstring_end)\n",
|
||||
" if '\"\"\"' in in_docstring[0] or \"'''\" in in_docstring[0]:\n",
|
||||
" in_docstring[0] = in_docstring[0].replace('\"\"\"', \"\").replace(\"...\", \"\").strip()\n",
|
||||
" request = \"Write a test for a Python function with the following docstring: \" + \" \".join(\n",
|
||||
" [p.strip() for p in in_docstring]\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" return request\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_test_code(sample: dict) -> List[str]:\n",
|
||||
" test = sample[\"test\"]\n",
|
||||
" test_lines = test.splitlines()\n",
|
||||
" start = 0\n",
|
||||
" for i, line in enumerate(test_lines):\n",
|
||||
" if \"def check(\" in line:\n",
|
||||
" start = i\n",
|
||||
" return \"\\n\".join(test_lines[start:])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def process_file(file_path: Path, out_path: Path):\n",
|
||||
" lines = file_path.read_text().splitlines()\n",
|
||||
" samples = list(map(json.loads, lines))\n",
|
||||
"\n",
|
||||
" output = []\n",
|
||||
" for sample in samples:\n",
|
||||
" prompt = get_request(sample)\n",
|
||||
" test = get_test_code(sample)\n",
|
||||
" output.append({\"prompt\": prompt, \"solution\": test})\n",
|
||||
"\n",
|
||||
" with open(out_path, \"w\") as f:\n",
|
||||
" for sample in output:\n",
|
||||
" f.write(json.dumps(sample))\n",
|
||||
" f.write(\"\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for file_path, out_path in zip(FILE_PATHS, OUT_PATHS):\n",
|
||||
" process_file(file_path, out_path)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.10.5 ('venv': venv)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.5"
|
||||
},
|
||||
"orig_nbformat": 4,
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "1f9a0efd3e4a33b8f30a65df6ca5a95cc3f93ce2f11519ee8c13fe711de61465"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
# CodeT Datasets
|
||||
|
||||
This folder contains two notebooks.
|
||||
|
||||
One will download the data used for Microsoft CodeT for tuning a model for
|
||||
Python code generation from function docstrings, augment the data into prompt
|
||||
and solution pairs and write them to `.jsonl` files.
|
||||
|
||||
The other will download the data used for Microsoft CodeT for tuning a model for
|
||||
Python test generation from corresponding function docstrings, augment the data
|
||||
into prompt and solution pairs and write them to `.jsonl` files.
|
||||
|
||||
## Requirements
|
||||
|
||||
Both notebooks require the library `requests`.
|
||||
@@ -0,0 +1,11 @@
|
||||
# DIVERSE Downloader
|
||||
|
||||
Diverse is a notebook that downloads the DIVERSE dataset and converts it into
|
||||
OpenAssistant Data Scheme formats.
|
||||
|
||||
---
|
||||
|
||||
## Contributing
|
||||
|
||||
Feel free to contribute to this notebook. It's not perfect and additional
|
||||
functionality is planned.
|
||||
@@ -0,0 +1,477 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "00b2848c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Diverse Dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"[](https://colab.research.google.com/drive/1CmXjXVrmPtpAVBaogBSuDclM0O6Zzewf?usp=sharing)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d81932b9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The purpose of this notebook is to download the DIVERSE dataset and convert it into a format that can be used for training the OpenAssistant.\n",
|
||||
"\n",
|
||||
"The DIVERSE repo can be found here: https://github.com/microsoft/CodeT/tree/main/DIVERSE\n",
|
||||
"\n",
|
||||
"If you extend or use this work, please cite the relevant papers:\n",
|
||||
"```\n",
|
||||
"@article{li2022advance,\n",
|
||||
" title={On the Advance of Making Language Models Better Reasoners},\n",
|
||||
" author={Li, Yifei and Lin, Zeqi and Zhang, Shizhuo and Fu, Qiang and Chen, Bei and Lou, Jian-Guang and Chen, Weizhu},\n",
|
||||
" journal={arXiv preprint arXiv:2206.02336},\n",
|
||||
" year={2022}\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a8c98078",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# OpenAssistant Data Scheme"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2731f88f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We will use the data scheme that can be found in the docs for Open-Assistant. This code is taken from the StackExchange notebook."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "d35ab066",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import TypeVar, List, Dict, Any, Literal\n",
|
||||
"from json import JSONEncoder\n",
|
||||
"\n",
|
||||
"T = TypeVar(\"T\", bound=\"ConversationTreeNode\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class ConversationTreeNode:\n",
|
||||
" text: str # The text of the node\n",
|
||||
" role: Literal[\"prompter\", \"assistant\"] # Whether the node is a user prompt/follow-up or an assistant response\n",
|
||||
" children: List[T] # The children of the node (if you have a linear conversation, this will be of length 0 or 1)\n",
|
||||
" metadata: Dict[str, Any] # Node metadata (see below)\n",
|
||||
"\n",
|
||||
" def __init__(\n",
|
||||
" self, text: str, role: Literal[\"prompter\", \"assistant\"], children: List[T], metadata: Dict[str, Any]\n",
|
||||
" ) -> None:\n",
|
||||
" self.text = text\n",
|
||||
" self.role = role\n",
|
||||
" self.children = children\n",
|
||||
" self.metadata = metadata\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class ConversationTree:\n",
|
||||
" root: ConversationTreeNode # The node containing the initial prompt\n",
|
||||
" metadata: Dict[str, Any] # Tree metadata, different from root node metadata.\n",
|
||||
"\n",
|
||||
" def __init__(self, root: ConversationTreeNode, metadata: Dict[str, Any]) -> None:\n",
|
||||
" self.root = root\n",
|
||||
" self.metadata = metadata\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# subclass JSONEncoder\n",
|
||||
"class TreeEncoder(JSONEncoder):\n",
|
||||
" def default(self, o):\n",
|
||||
" return o.__dict__"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e7457bae",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Download and convert"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "54b0fd63",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We firstly import pandas and any other libraries that we'll need."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "9317d4b4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"import json"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "62dc4e18",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The following is a simple function to take the data (which has two columns) and convert it to a tree with a root note (question) and one child (answer)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"id": "963e0d92",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import re\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def convert_diverse(dataset_json_path):\n",
|
||||
" # read files using pandas\n",
|
||||
" ds = pd.read_json(dataset_json_path, lines=True)\n",
|
||||
"\n",
|
||||
" # create dataset name from path\n",
|
||||
" ds_name = \"diverse\" + file.split(\"data\")[-1].replace(\"/\", \"_\").split(\".\")[0]\n",
|
||||
" print(\"*****\", ds_name, \"****\")\n",
|
||||
" print(\"Example of raw dataset\")\n",
|
||||
" print(ds.head(2))\n",
|
||||
"\n",
|
||||
" # create conversation forest\n",
|
||||
" # Print first sample so the user of this notebook has an idea of what he's looking at\n",
|
||||
" first_sample = True\n",
|
||||
" print(\"\\nExamples from converted dataset\")\n",
|
||||
" conversation_forest = []\n",
|
||||
" for item in ds[\"context\"]:\n",
|
||||
" # build nodes and tree\n",
|
||||
" # Find all answers:\n",
|
||||
"\n",
|
||||
" answers = re.findall(r\"Answer:?(.*?)#\", item.replace(\"\\n\", \" \"))\n",
|
||||
" questions = re.findall(r\"Question:?(.*?) Answer:\", item.replace(\"\\n\", \" \"))\n",
|
||||
"\n",
|
||||
" # The last question does not contain an aswer so we drop it every time.\n",
|
||||
" if len(answers) < len(questions):\n",
|
||||
" questions.pop(-1)\n",
|
||||
"\n",
|
||||
" for (answer, question) in zip(answers, questions):\n",
|
||||
" if first_sample:\n",
|
||||
" print(f\"Q: {question}\")\n",
|
||||
" print(f\"A: {answer}\")\n",
|
||||
" root = ConversationTreeNode(text=question, role=\"prompter\", children=[], metadata=None)\n",
|
||||
" child = ConversationTreeNode(text=answer, role=\"assistant\", children=[], metadata=None)\n",
|
||||
" root.children.append(child)\n",
|
||||
" conversation_tree = ConversationTree(root=root, metadata={\"dataset\": ds_name})\n",
|
||||
" conversation_forest.append(conversation_tree)\n",
|
||||
"\n",
|
||||
" first_sample = False\n",
|
||||
"\n",
|
||||
" conversation_forest_json = [\n",
|
||||
" json.loads(TreeEncoder().encode(conversation_tree)) for conversation_tree in conversation_forest\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" print(json.dumps(conversation_forest_json, indent=4), file=open(f\"./{ds_name}.json\", \"w+\"))\n",
|
||||
" print(\"\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e4448c9a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We now clone the repository containing the dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "06e7719e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Cloning into 'CodeT'...\r\n",
|
||||
"remote: Enumerating objects: 144, done.\u001B[K\r\n",
|
||||
"remote: Counting objects: 100% (16/16), done.\u001B[K\r\n",
|
||||
"remote: Compressing objects: 100% (16/16), done.\u001B[K\r\n",
|
||||
"remote: Total 144 (delta 1), reused 0 (delta 0), pack-reused 128\u001B[Kving objects: 8% (12/144), 3.70 MiB | 1.35 MiB/s Receiving objects: 12% (18/144), 5.13 MiB | 1.58 MiB/s Receiving objects: 13% (19/144), 11.36 MiB | 2.39 MiB/s Receiving objects: 13% (20/144), 19.19 MiB | 3.97 MiB/s Receiving objects: 15% (22/144), 19.19 MiB | 3.97 MiB/s Receiving objects: 23% (34/144), 22.15 MiB | 4.37 MiB/s Receiving objects: 27% (39/144), 22.15 MiB | 4.37 MiB/s Receiving objects: 29% (42/144), 25.30 MiB | 4.77 MiB/s Receiving objects: 32% (47/144), 28.71 MiB | 5.22 MiB/s Receiving objects: 41% (60/144), 32.41 MiB | 5.62 MiB/s Receiving objects: 54% (78/144), 32.41 MiB | 5.62 MiB/s Receiving objects: 60% (87/144), 39.34 MiB | 6.20 MiB/s Receiving objects: 61% (88/144), 47.14 MiB | 6.79 MiB/s Receiving objects: 64% (93/144), 51.36 MiB | 7.12 MiB/s Receiving objects: 66% (96/144), 55.58 MiB | 7.40 MiB/s \r\n",
|
||||
"Receiving objects: 100% (144/144), 56.76 MiB | 4.97 MiB/s, done.\r\n",
|
||||
"Resolving deltas: 100% (33/33), done.\r\n",
|
||||
"Checking out files: 100% (64/64), done.\r\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!git clone https://github.com/microsoft/CodeT.git"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"id": "89a166c9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"diverse_files = [\n",
|
||||
" \"CodeT/DIVERSE/data/sqa/split1/test.jsonl\",\n",
|
||||
" \"CodeT/DIVERSE/data/sqa/split1/train.jsonl\",\n",
|
||||
" \"CodeT/DIVERSE/data/sqa/split2/test.jsonl\",\n",
|
||||
" \"CodeT/DIVERSE/data/sqa/split2/train.jsonl\",\n",
|
||||
" \"CodeT/DIVERSE/data/gsm8k/test.jsonl\",\n",
|
||||
" \"CodeT/DIVERSE/data/gsm8k/train.jsonl\",\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"id": "2da75b14",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"***** diverse_sqa_split1_test ****\n",
|
||||
"Example of raw dataset\n",
|
||||
" context \\\n",
|
||||
"0 Question:\\nIs clerk of Supreme Court of Canada... \n",
|
||||
"1 Question:\\nIs Saturn named after king of gods ... \n",
|
||||
"\n",
|
||||
" samples \\\n",
|
||||
"0 [Snoopy is a dog.\\nChance is a dog.\\nDogs look... \n",
|
||||
"1 [Snoopy is a cartoon dog.\\nChance is a cartoon... \n",
|
||||
"\n",
|
||||
" metadata \n",
|
||||
"0 {'type': 'solution', 'question': 'Does Snoopy ... \n",
|
||||
"1 {'type': 'solution', 'question': 'Does Snoopy ... \n",
|
||||
"\n",
|
||||
"Examples from converted dataset\n",
|
||||
"Q: Is clerk of Supreme Court of Canada safe profession for someone with seismophobia?\n",
|
||||
"A: The Supreme Court of Canada is in Ottawa, Canada. Ottawa is in Ontario, Canada. Ontario is in the Canadian Shield. The Canadian Shield is a stable tectonic plate. Thus, Ottawa is not prone to earthquakes. Thus, the clerk of the Supreme Court of Canada is a safe profession for someone with seismophobia. So the answer is yes.\n",
|
||||
"Q: During the Cuban revolution, did the US experience a population boom?\n",
|
||||
"A: The Cuban revolution was in 1959. The US population in 1959 was about 180 million. The US population in 2010 was about 310 million. Thus, the US population increased by about 130 million. So the answer is yes.\n",
|
||||
"Q: Can the largest crustacean stretch out completely on a king-sized mattress?\n",
|
||||
"A: The largest crustacean is the Japanese spider crab. The Japanese spider crab has a leg span of 3.8 meters. A king-sized mattress is 1.9 meters wide. Thus, the Japanese spider crab could not stretch out completely on a king-sized mattress. So the answer is no.\n",
|
||||
"Q: Could morphine cure HIV?\n",
|
||||
"A: Morphine is a painkiller. HIV is a virus. Painkillers do not cure viruses. Thus, morphine could not cure HIV. So the answer is no.\n",
|
||||
"Q: Is Christopher Nolan indebted to Bob Kane?\n",
|
||||
"A: Bob Kane created Batman. Christopher Nolan directed the Dark Knight trilogy. The Dark Knight trilogy is about Batman. Thus, Christopher Nolan is indebted to Bob Kane. So the answer is yes.\n",
|
||||
"Q: In baseball, is a \"Homer\" named after the poet Homer who wrote the Odyssey?\n",
|
||||
"A: A \"Homer\" is a home run. The term \"home run\" was coined by Harry Wright in 1858. Homer lived from about 800 BC to about 700 BC. Thus, the term \"home run\" was not named after the poet Homer. So the answer is no.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"***** diverse_sqa_split1_train ****\n",
|
||||
"Example of raw dataset\n",
|
||||
" context \\\n",
|
||||
"0 Question:\\nIs eating a Dicopomorpha echmeptery... \n",
|
||||
"1 Question:\\nWas Alexander the Great baptized?\\n... \n",
|
||||
"\n",
|
||||
" samples \\\n",
|
||||
"0 [New York Public Library is a library.\\nSix Fl... \n",
|
||||
"1 [The New York Public Library is in New York Ci... \n",
|
||||
"\n",
|
||||
" metadata \n",
|
||||
"0 {'type': 'solution', 'question': 'Could you go... \n",
|
||||
"1 {'type': 'solution', 'question': 'Could you go... \n",
|
||||
"\n",
|
||||
"Examples from converted dataset\n",
|
||||
"Q: Is eating a Dicopomorpha echmepterygis size Uranium pellet fatal?\n",
|
||||
"A: The Dicopomorpha echmepterygis is a beetle. The mass of a beetle is about 0.1 g. The mass of a Uranium pellet is about 10 g. Thus, eating a Dicopomorpha echmepterygis is not the same as eating a Uranium pellet. So the answer is no.\n",
|
||||
"Q: Does Jerry Seinfeld hang out at the Budweiser Party Deck?\n",
|
||||
"A: Jerry Seinfeld is a comedian. Comedians perform at comedy clubs. The Budweiser Party Deck is not a comedy club. Thus, Jerry Seinfeld does not hang out at the Budweiser Party Deck. So the answer is no.\n",
|
||||
"Q: Does The Jungle Book contain racist subtext?\n",
|
||||
"A: The Jungle Book is a book about a boy raised by wolves. The boy is white. The wolves are black. Thus, The Jungle Book contains racist subtext. So the answer is yes.\n",
|
||||
"Q: Does Dean Cain have less days to birthday than Will Ferrell every 4th of July?\n",
|
||||
"A: Dean Cain was born on July 31. Will Ferrell was born on July 16. July 4 is the middle of July. Thus, Dean Cain has more days to his birthday than Will Ferrell every 4th of July. So the answer is no.\n",
|
||||
"Q: Do mail carriers need multiple uniforms?\n",
|
||||
"A: Mail carriers need to wear uniforms. Uniforms are clothes. People need multiple clothes. Thus, mail carriers need multiple uniforms. So the answer is yes.\n",
|
||||
"Q: Can Curiosity take samples of rocks from Lacus Temporis?\n",
|
||||
"A: Curiosity is a rover on Mars. Lacus Temporis is a lake on Mars. Curiosity is not in Lacus Temporis. Thus, Curiosity cannot take samples of rocks from Lacus Temporis. So the answer is no.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"***** diverse_sqa_split2_test ****\n",
|
||||
"Example of raw dataset\n",
|
||||
" context \\\n",
|
||||
"0 Question:\\nDoes the density of helium cause vo... \n",
|
||||
"1 Question:\\nIs Mark Cuban able to visit Norther... \n",
|
||||
"\n",
|
||||
" samples \\\n",
|
||||
"0 [The New York Public Library is in New York Ci... \n",
|
||||
"1 [The distance from New York Public Library to ... \n",
|
||||
"\n",
|
||||
" metadata \n",
|
||||
"0 {'type': 'solution', 'question': 'Could you go... \n",
|
||||
"1 {'type': 'solution', 'question': 'Could you go... \n",
|
||||
"\n",
|
||||
"Examples from converted dataset\n",
|
||||
"Q: Does the density of helium cause voices to sound deeper?\n",
|
||||
"A: The density of helium is 0.1785 g/L, which is less than air. The density of air is 1.225 g/L. The density of helium is less than air, so helium causes voices to sound higher. Thus, helium does not cause voices to sound deeper. So the answer is no.\n",
|
||||
"Q: Would Janet Jackson avoid a dish with ham?\n",
|
||||
"A: Janet Jackson is a vegetarian. Vegetarians do not eat meat. Ham is a type of meat. Thus, Janet Jackson would avoid a dish with ham. So the answer is yes.\n",
|
||||
"Q: Do people watching Coen brothers films in Guinea Bissau need subtitles?\n",
|
||||
"A: The Coen brothers are American. Americans speak English. Guinea Bissau is in Africa. Africans speak Portuguese. Thus, people in Guinea Bissau would need subtitles to watch Coen brothers films. So the answer is yes.\n",
|
||||
"Q: Could a Gladiator's weapon crush a diamond?\n",
|
||||
"A: A gladiator's weapon was a sword. The hardness of a diamond is 10. The hardness of a sword is 5. Thus, a gladiator's weapon could not crush a diamond. So the answer is no.\n",
|
||||
"Q: Can Spartina Patens thrive in the Sahara Desert?\n",
|
||||
"A: Spartina Patens is a salt marsh grass. The Sahara Desert is a desert. Deserts are dry. Thus, Spartina Patens would not thrive in the Sahara Desert. So the answer is no.\n",
|
||||
"Q: Were all the materials to make a cannon known during the bronze age?\n",
|
||||
"A: The bronze age was about 3000 BC to 1200 BC. The bronze age was before the iron age. Thus, iron was not known during the bronze age. Iron is needed to make a cannon. Thus, all the materials to make a cannon were not known during the bronze age. So the answer is no.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"***** diverse_sqa_split2_train ****\n",
|
||||
"Example of raw dataset\n",
|
||||
" context \\\n",
|
||||
"0 Question:\\nWas the Peak of the Andes hidden fr... \n",
|
||||
"1 Question:\\nIs Mark Cuban able to visit Norther... \n",
|
||||
"\n",
|
||||
" samples \\\n",
|
||||
"0 [Snoopy is a dog.\\nChance is a dog.\\nChance is... \n",
|
||||
"1 [Snoopy is a cartoon dog.\\nChance is a cartoon... \n",
|
||||
"\n",
|
||||
" metadata \n",
|
||||
"0 {'type': 'solution', 'question': 'Does Snoopy ... \n",
|
||||
"1 {'type': 'solution', 'question': 'Does Snoopy ... \n",
|
||||
"\n",
|
||||
"Examples from converted dataset\n",
|
||||
"Q: Was the Peak of the Andes hidden from the view of the Colossus of Rhodes?\n",
|
||||
"A: The Colossus of Rhodes was a statue of the Greek god Helios. The Colossus of Rhodes was located on the island of Rhodes. The Peak of the Andes is in South America. Thus, the Peak of the Andes was not visible from the Colossus of Rhodes. So the answer is yes.\n",
|
||||
"Q: Can you swim to Miami from New York?\n",
|
||||
"A: The distance from New York to Miami is about 1,500 miles. The fastest swimmer can swim about 2 miles per hour. Thus, it would take about 750 hours to swim from New York to Miami. Thus, you could not swim to Miami from New York. So the answer is no.\n",
|
||||
"Q: Is Freya a combination of Athena and Aphrodite?\n",
|
||||
"A: Freya is the Norse goddess of love and beauty. Athena is the Greek goddess of wisdom and war. Aphrodite is the Greek goddess of love and beauty. Thus, Freya is a combination of Athena and Aphrodite. So the answer is yes.\n",
|
||||
"Q: Were the Great Pyramids built by a theocratic government?\n",
|
||||
"A: The Great Pyramids were built by the Egyptians. The Egyptians were ruled by a pharaoh. The pharaoh was considered a god. Thus, the Great Pyramids were built by a theocratic government. So the answer is yes.\n",
|
||||
"Q: Was P. G. Wodehouse's favorite book The Hunger Games?\n",
|
||||
"A: P. G. Wodehouse's favorite book was The Pickwick Papers. The Pickwick Papers was written by Charles Dickens. The Hunger Games was written by Suzanne Collins. Thus, P. G. Wodehouse's favorite book was not The Hunger Games. So the answer is no.\n",
|
||||
"Q: Can Burundi's communicate with citizens of New Brunswick?\n",
|
||||
"A: Burundi's speak French. New Brunswick is in Canada. Canada is a bilingual country. Thus, Burundi's can communicate with citizens of New Brunswick. So the answer is yes.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"***** diverse_gsm8k_test ****\n",
|
||||
"Example of raw dataset\n",
|
||||
" context \\\n",
|
||||
"0 Question:\\nA community is building a metal fen... \n",
|
||||
"1 Question:\\nThe white rabbit can hop 15 meters ... \n",
|
||||
"\n",
|
||||
" samples \\\n",
|
||||
"0 [She eats 3 and uses 4, so that is 7 eggs.\\n16... \n",
|
||||
"1 [She eats 3 eggs for breakfast and uses 4 in m... \n",
|
||||
"\n",
|
||||
" metadata \n",
|
||||
"0 {'question': 'Janet’s ducks lay 16 eggs per da... \n",
|
||||
"1 {'question': 'Janet’s ducks lay 16 eggs per da... \n",
|
||||
"\n",
|
||||
"Examples from converted dataset\n",
|
||||
"Q: A community is building a metal fence. Each fence panel is made of 3 metal sheets, and 2 metal beams. The fence is made of 10 fence panels. If each sheet is made of 10 metal rods and each metal beam is made of 4 metal rods, how many metal rods does the community need for the fence?\n",
|
||||
"A: In each panel, the metal sheets use 3 metal sheets * 10 metal rods = <<3*10=30>>30 metal rods. In each panel, the metal beams use 2 metal beams * 4 metal rods = <<2*4=8>>8 metal rods. So each panel uses 30 + 8 = <<30+8=38>>38 metal rods. The entire fence therefore needs 38 metal rods * 10 fence panels = <<38*10=380>>380 metal rods. \n",
|
||||
"Q: John buys 3 dress shirts. They sell for $20 each. He also has to pay 10% tax on everything. How much did he pay in total?\n",
|
||||
"A: The shirts cost 3*$20=$<<3*20=60>>60 before tax The tax cost $60*.1=$<<60*.1=6>>6 So in total they paid $60+$6=$<<60+6=66>>66 \n",
|
||||
"Q: Bob gets rent assistance because he's low-income. If he gets a raise of $0.50/hour and works 40 hours a week, how much more will he actually earn a week if his housing benefit is reduced by $60/month?\n",
|
||||
"A: First find the total increase in Bob's earnings: $0.50/hour * 40 hours/week = $<<0.50*40=20>>20/week Then find the weekly decrease in Bob's housing assistance: $60/month / 4 weeks/month = $<<60/4=15>>15/week Then subtract the lost assistance from the increased wages to find Bob's net increase in money: $20/week - $15/week = $<<20-15=5>>5/week \n",
|
||||
"Q: Annie plants 3 pots of basil, 9 pots of rosemary, and 6 pots of thyme. Each basil plant has 4 leaves, each rosemary plant has 18 leaves, and each thyme plant has 30 leaves. How many leaves are there total?\n",
|
||||
"A: First find the total number of basil leaves: 3 pots * 4 leaves/pot = <<3*4=12>>12 leaves Then find the total number of rosemary leaves: 9 pots * 18 leaves/pot = <<9*18=162>>162 leaves Then find the total number of thyme leaves: 6 pots * 30 leaves/pot = <<6*30=180>>180 leaves Then add the number of each type of leaf to find the total number of leaves: 12 leaves + 162 leaves + 180 leaves = <<12+162+180=354>>354 leaves \n",
|
||||
"Q: There are 7 mL of solution in each of 6 test tubes. Dr. Igor takes all of the solution and then evenly distributes it into 3 beakers. How many mL of solution are in each beaker?\n",
|
||||
"A: 7 * 6 = <<7*6=42>>42 mL 42/3 = <<42/3=14>>14 mL Each beaker holds 14 mL of solution. \n",
|
||||
"Q: Janet pays $40/hour for 3 hours per week of clarinet lessons and $28/hour for 5 hours a week of piano lessons. How much more does she spend on piano lessons than clarinet lessons in a year?\n",
|
||||
"A: First find the total Janet spends on clarinet lessons per week: $40/hour * 3 hours/week = $<<40*3=120>>120/week Then find the total Janet spends on piano lessons per week: $28/hour * 5 hours/week = $<<28*5=140>>140/week Then subtract her weekly clarinet spending from her weekly piano spending to find the weekly difference: $140/week - $120/week = $<<140-120=20>>20/week Then multiply the weekly difference by the number of weeks in a year to find the annual difference: $20/week * 52 weeks/year = $<<20*52=1040>>1040/year \n",
|
||||
"Q: A normal lemon tree produces 60 lemons per year. Jim has specially engineered lemon trees that produce 50% more lemons per year. He has a grove that is 50 trees by 30 trees. How many lemons does he produce in 5 years?\n",
|
||||
"A: Each tree produces 60*.5=<<60*.5=30>>30 more lemons than normal So they each produce 60+30=<<60+30=90>>90 lemons He has 50*30=<<50*30=1500>>1500 trees So every year he produces 1500*90=<<1500*90=135000>>135000 lemons That means he produces 135000*5=<<135000*5=675000>>675,000 \n",
|
||||
"Q: Billy weighs 9 pounds more than Brad. Brad weighs 5 pounds more than Carl. If Carl weighs 145 pounds, how much does Billy weigh, in pounds?\n",
|
||||
"A: Brad weighs 145+5=<<145+5=150>>150 pounds. Billy weighs 150+9=<<150+9=159>>159 pounds. \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"***** diverse_gsm8k_train ****\n",
|
||||
"Example of raw dataset\n",
|
||||
" context \\\n",
|
||||
"0 Question:\\nA magician has a top hat with 20 re... \n",
|
||||
"1 Question:\\nA community is building a metal fen... \n",
|
||||
"\n",
|
||||
" samples \\\n",
|
||||
"0 [There are 125 cars in total\\n64% of them are ... \n",
|
||||
"1 [64% of 125 is <<64/100*125=80>>80.\\n8% of 125... \n",
|
||||
"\n",
|
||||
" metadata \n",
|
||||
"0 {'question': 'Pauline has 125 matchbox cars. T... \n",
|
||||
"1 {'question': 'Pauline has 125 matchbox cars. T... \n",
|
||||
"\n",
|
||||
"Examples from converted dataset\n",
|
||||
"Q: A magician has a top hat with 20 red marbles and a top hat with 30 blue marbles. If he takes away 3 red marbles and four times as many blue marbles as red marbles (without looking), how many marbles in total does he have left?\n",
|
||||
"A: He had 20 red marbles and took away 3 leaving 20-3 = <<20-3=17>>17 red marbles He took 4 times as many blue marbles as red marbles which is 4*3 = <<4*3=12>>12 blue marbles He took 12 blue marbles from 30 leaving 30-12 = 18 blue marbles He now has 17+18 = <<17+18=35>>35 marbles left \n",
|
||||
"Q: Lucas wants to get a dog but his parents think he already has too many pets and won't have enough space. He already has 12 pet beds in his room but manages to fit another 8 pet beds. His parents argue that each pet is going to need 2 beds each to feel comfortable. According to his parent's argument, how many pets does Lucas have enough room for?\n",
|
||||
"A: Lucas has a total of 12 existing pet beds + 8 new pet beds = <<12+8=20>>20 pet beds. So according to his parents, Lucas has enough room for 20 pet beds / 2 pet beds per pet = <<20/2=10>>10 pets. \n",
|
||||
"Q: Super Clean Car Wash Company cleans 80 cars per day. They make $5 per car washed. How much money will they make in 5 days?\n",
|
||||
"A: Each day they will make 80 × $5 = $<<80*5=400>>400. They will make $400 × 5 = $<<400*5=2000>>2000 in 5 days. \n",
|
||||
"Q: Eighteen hours ago, Beth and I took 100 photographs of our project. Today, Beth and I will take 20% fewer photographs of the same project. If we were to take 300 photographs of the project, how many photographs would we take to reach the target?\n",
|
||||
"A: If you took 100 photographs of the project 18 hours ago, and today 20% few photographs have been taken, then 20/100*100 = 20 fewer photographs of the project have been taken today. The total number of photographs of the project that have been taken today is 100-20 = <<100-20=80>>80 So far, you've taken 80+100 = <<80+100=180>>180 photographs of the project. Since the target number of photographs is 300, the number of photographs that you need to take to reach the target is 300-180 = <<300-180=120>>120 \n",
|
||||
"Q: Ruby was going to order pizza for dinner. Her son would only eat pepperoni pizza. Her daughter would only eat sausage. Ruby and her husband wanted black olive and mushroom pizza. To make life easy, Ruby decided to order an entire pizza for each of her children and she would split one with her husband. The pizza restaurant charged $10.00 per pizza and $1.00 per topping. She also needed to add a $5.00 tip. Including tip, how much would the pizza order cost?\n",
|
||||
"A: Ruby was going to order 1 for her son, 1 for her daughter and 1 to share with her husband. So she needed to order 1+1+1 = <<1+1+1=3>>3 pizzas Each pizza cost $10 and she was ordering 3 so that comes to 10*3 = $<<10*3=30.00>>30.00 She needed to order pepperoni, sausage, black olive and mushroom, which came to 4 toppings at $1.00 each so 4*1 = $<<4*1=4.00>>4.00 extra for toppings The pizzas cost $30 and $4 for the toppings so the total costs of the pizzas came to 30+4 = $<<30+4=34.00>>34.00 She also had to add a $5.00 tip to her current order total of $34.00 so 5+34.00 = $<<5+34=39.00>>39.00 for the total order \n",
|
||||
"Q: Yves and his siblings ordered pizza and asked to have it cut into 16 slices. During dinner time, they only ate one-fourth of it. The next day, Yves ate one-fourth of the remaining pizza. Then his two siblings ate 2 slices each. How many slices of pizza were left?\n",
|
||||
"A: During dinner time, Yves and his siblings ate 16/4 = <<16/4=4>>4 slices. So the next day, there were still 16 - 4 = <<16-4=12>>12 slices left. The next day, Yves ate 12/4 = <<12/4=3>>3 slices of pizza. Thus, there were 12 - 3 = <<12-3=9>>9 slices left. Then, his two siblings ate 2 x 2 = <<2*2=4>>4 slices of pizza. Therefore, there were still 9 - 4 = <<9-4=5>>5 slices of pizza left. \n",
|
||||
"Q: Pria bought a new car that advertised an estimated gas mileage of 35 miles per gallon. The car has a 12-gallon tank. She filled her car full of gas and was able to drive a total of 372 miles. What was the difference, in miles per gallon, between Pria's mileage and the advertised mileage?\n",
|
||||
"A: Pria's car achieved a rate of 372 miles / 12 gallons = <<372/12=31>>31 miles per gallon. Therefore, it was a difference of 35 - 31 = <<35-31=4>>4 miles per gallon. \n",
|
||||
"Q: Janet has 24 dresses. Half of them have pockets. Of those, a third have 2 pockets and the rest have 3 pockets. How many total pockets do her dresses have?\n",
|
||||
"A: She has 24/2=<<24/2=12>>12 dresses with pockets Of those 12/3=4 have 2 pockets So 12-4=<<12-4=8>>8 have three pockets So the dresses with 2 pockets have 2*4=<<2*4=8>>8 pockets The other dresses contribute 8*3=<<8*3=24>>24 pockets So she has a total of 8+24=<<8+24=32>>32 pockets \n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for file in diverse_files:\n",
|
||||
" convert_diverse(file)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"name": "conda-env-open-assistant-py",
|
||||
"language": "python",
|
||||
"display_name": "Python [conda env:open-assistant]"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.0"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "25d5c2324055587ceaeef27650c79ce8358ea61d7689f2e0b8ada5d53f85bce4"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -40,7 +40,7 @@ class OasstErrorCode(IntEnum):
|
||||
RATING_OUT_OF_RANGE = 2002
|
||||
INVALID_RANKING_VALUE = 2003
|
||||
INVALID_TASK_TYPE = 2004
|
||||
USER_NOT_SPECIFIED = 2005
|
||||
|
||||
NO_MESSAGE_TREE_FOUND = 2006
|
||||
NO_REPLIES_FOUND = 2007
|
||||
INVALID_MESSAGE = 2008
|
||||
@@ -62,11 +62,15 @@ class OasstErrorCode(IntEnum):
|
||||
TASK_NOT_COLLECTIVE = 2106
|
||||
TASK_NOT_ASSIGNED_TO_USER = 2106
|
||||
TASK_UNEXPECTED_PAYLOAD_TYPE_ = 2107
|
||||
USER_NOT_FOUND = 2200
|
||||
|
||||
# 3000-4000: external resources
|
||||
HUGGINGFACE_API_ERROR = 3001
|
||||
|
||||
# 4000-5000: user
|
||||
USER_NOT_SPECIFIED = 4000
|
||||
USER_DISABLED = 4001
|
||||
USER_NOT_FOUND = 4002
|
||||
|
||||
|
||||
class OasstError(Exception):
|
||||
"""Base class for Open-Assistant exceptions."""
|
||||
|
||||
@@ -392,6 +392,7 @@ class UserScore(BaseModel):
|
||||
|
||||
class LeaderboardStats(BaseModel):
|
||||
time_frame: str
|
||||
last_updated: datetime
|
||||
leaderboard: List[UserScore]
|
||||
|
||||
|
||||
|
||||
@@ -10,14 +10,47 @@ def utcnow() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def unaware_to_utc(d: datetime | None) -> datetime:
|
||||
"""Set timezeno to UTC if datetime is unaware (tzinfo == None)."""
|
||||
if d and d.tzinfo is None:
|
||||
return d.replace(tzinfo=timezone.utc)
|
||||
return d
|
||||
|
||||
|
||||
class TimerError(Exception):
|
||||
"""A custom exception used to report errors in use of Timer class"""
|
||||
|
||||
|
||||
class ScopeTimer:
|
||||
def __init__(self):
|
||||
self.start()
|
||||
|
||||
def start(self) -> None:
|
||||
"""Measure new start time"""
|
||||
self.start_time = time.perf_counter()
|
||||
|
||||
def stop(self) -> float:
|
||||
"""Store and return the elapsed time"""
|
||||
self.elapsed = time.perf_counter() - self.start_time
|
||||
return self.elapsed
|
||||
|
||||
def __enter__(self):
|
||||
"""Start a new timer as a context manager"""
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc_info):
|
||||
"""Stop the context manager timer"""
|
||||
self.stop()
|
||||
|
||||
|
||||
def log_timing(func=None, *, log_kwargs: bool = False, level: int | str = "DEBUG") -> None:
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapped(*args, **kwargs):
|
||||
start = time.time()
|
||||
timer = ScopeTimer()
|
||||
result = func(*args, **kwargs)
|
||||
end = time.time()
|
||||
elapsed = end - start
|
||||
elapsed = timer.stop()
|
||||
if log_kwargs:
|
||||
kwargs = ", ".join([f"{k}={v}" for k, v in kwargs.items()])
|
||||
logger.log(level, f"Function '{func.__name__}({kwargs})' executed in {elapsed:f} s")
|
||||
|
||||
@@ -101,7 +101,7 @@ def ranked_pairs(ranks: List[List[int]]):
|
||||
# order by strength of victory (using tideman's original method, don't think it would make a difference for us)
|
||||
sorted_majorities = []
|
||||
for i in range(len(ranks[0])):
|
||||
for j in range(len(ranks[i])):
|
||||
for j in range(len(ranks[0])):
|
||||
if tallies[i, j] > 0:
|
||||
sorted_majorities.append((i, j, tallies[i, j]))
|
||||
# we don't explicitly deal with tied majorities here
|
||||
@@ -132,8 +132,8 @@ if __name__ == "__main__":
|
||||
[("w", "x", "z", "y") for _ in range(1)]
|
||||
+ [("w", "y", "x", "z") for _ in range(2)]
|
||||
# + [("x","y","z","w") for _ in range(4)]
|
||||
+ [("x", "z", "w", "y") for _ in range(5)]
|
||||
+ [("y", "w", "x", "z") for _ in range(1)]
|
||||
# + [("x", "z", "w", "y") for _ in range(5)]
|
||||
# + [("y", "w", "x", "z") for _ in range(1)]
|
||||
# [("y","z","w","x") for _ in range(1000)]
|
||||
)
|
||||
rp = ranked_pairs(ranks)
|
||||
|
||||
@@ -0,0 +1,250 @@
|
||||
"""Simple REPL frontend."""
|
||||
|
||||
import http
|
||||
import random
|
||||
|
||||
import requests
|
||||
import typer
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
# debug constants
|
||||
USER = {"id": "1234", "display_name": "John Doe", "auth_method": "local"}
|
||||
|
||||
|
||||
def _random_message_id():
|
||||
return str(random.randint(1000, 9999))
|
||||
|
||||
|
||||
def _render_message(message: dict) -> str:
|
||||
"""Render a message to the user."""
|
||||
if message["is_assistant"]:
|
||||
return f"Assistant: {message['text']}"
|
||||
return f"Prompter: {message['text']}"
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
|
||||
"""automates tasks"""
|
||||
|
||||
def _post(path: str, json: dict) -> dict:
|
||||
response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key})
|
||||
response.raise_for_status()
|
||||
if response.status_code == http.HTTPStatus.NO_CONTENT:
|
||||
return None
|
||||
return response.json()
|
||||
|
||||
def gen_random_text():
|
||||
return " ".join([random.choice(["hello", "world", "foo", "bar"]) for _ in range(10)])
|
||||
|
||||
def gen_random_ranking(messages):
|
||||
"""rank messages randomly and return list of indexes in order of rank randomly"""
|
||||
print("Ranking")
|
||||
print(messages)
|
||||
print(len(messages))
|
||||
ranks = [i for i in range(len(messages))]
|
||||
shuffled = random.shuffle(ranks)
|
||||
print(ranks)
|
||||
print(shuffled)
|
||||
return ranks
|
||||
|
||||
tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})]
|
||||
q = 0
|
||||
while tasks:
|
||||
task = tasks.pop(0)
|
||||
print(task)
|
||||
|
||||
match (task["type"]):
|
||||
case "initial_prompt":
|
||||
typer.echo("Please provide an initial prompt to the assistant.")
|
||||
if task["hint"]:
|
||||
typer.echo(f"Hint: {task['hint']}")
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
|
||||
prompt = gen_random_text()
|
||||
user_message_id = _random_message_id()
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_message",
|
||||
"message_id": message_id,
|
||||
"task_id": task["id"],
|
||||
"user_message_id": user_message_id,
|
||||
"text": prompt,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "label_initial_prompt":
|
||||
typer.echo("Label the following prompt:")
|
||||
typer.echo(task["prompt"])
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
|
||||
valid_labels = task["valid_labels"]
|
||||
|
||||
labels_dict = None
|
||||
if task["mode"] == "simple" and len(valid_labels) == 1:
|
||||
answer = random.choice([True, False])
|
||||
labels_dict = {valid_labels[0]: 1 if answer else 0}
|
||||
else:
|
||||
while labels_dict is None:
|
||||
labels = random.sample(valid_labels, random.randint(1, len(valid_labels)))
|
||||
|
||||
if all([label in valid_labels for label in labels]):
|
||||
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
|
||||
else:
|
||||
invalid_labels = [label for label in labels if label not in valid_labels]
|
||||
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
|
||||
|
||||
# send labels
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_labels",
|
||||
"message_id": task["message_id"],
|
||||
"task_id": task["id"],
|
||||
"text": task["prompt"],
|
||||
"labels": labels_dict,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
case "prompter_reply":
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
user_message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_message",
|
||||
"message_id": message_id,
|
||||
"task_id": task["id"],
|
||||
"user_message_id": user_message_id,
|
||||
"text": gen_random_text(),
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "assistant_reply":
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
user_message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_message",
|
||||
"message_id": message_id,
|
||||
"task_id": task["id"],
|
||||
"user_message_id": user_message_id,
|
||||
"text": gen_random_text(),
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "rank_prompter_replies" | "rank_assistant_replies":
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
user_message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
# send interaction
|
||||
ranking = gen_random_ranking(task["replies"])
|
||||
print(ranking)
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "message_ranking",
|
||||
"message_id": message_id,
|
||||
"task_id": task["id"],
|
||||
"ranking": ranking,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "rank_initial_prompts":
|
||||
# acknowledge task
|
||||
message_id = _random_message_id()
|
||||
user_message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
# send interaction
|
||||
ranking = gen_random_ranking(task["prompots"])
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "message_ranking",
|
||||
"message_id": message_id,
|
||||
"ranking": ranking,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "label_prompter_reply" | "label_assistant_reply":
|
||||
# acknowledge task
|
||||
typer.echo("Here is the conversation so far:")
|
||||
for message in task["conversation"]["messages"]:
|
||||
typer.echo(_render_message(message))
|
||||
|
||||
typer.echo("Label the following reply:")
|
||||
typer.echo(task["reply"])
|
||||
message_id = _random_message_id()
|
||||
user_message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
valid_labels = task["valid_labels"]
|
||||
|
||||
labels_dict = None
|
||||
if task["mode"] == "simple" and len(valid_labels) == 1:
|
||||
answer = random.choice([True, False])
|
||||
labels_dict = {valid_labels[0]: 1 if answer else 0}
|
||||
else:
|
||||
while labels_dict is None:
|
||||
labels = random.sample(valid_labels, random.randint(1, len(valid_labels)))
|
||||
|
||||
if all([label in valid_labels for label in labels]):
|
||||
labels_dict = {label: "1" if label in labels else "0" for label in valid_labels}
|
||||
else:
|
||||
invalid_labels = [label for label in labels if label not in valid_labels]
|
||||
typer.echo(f"Invalid labels: {', '.join(invalid_labels)}. Valid: {', '.join(valid_labels)}")
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_labels",
|
||||
"message_id": task["message_id"],
|
||||
"task_id": task["id"],
|
||||
"text": task["reply"],
|
||||
"labels": labels_dict,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
case "task_done":
|
||||
typer.echo("Task done!")
|
||||
# rerun with new task slected from above cases
|
||||
# add a new task
|
||||
q += 1
|
||||
if q == 10:
|
||||
typer.echo("Task done!")
|
||||
break
|
||||
tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})]
|
||||
#
|
||||
case _:
|
||||
typer.echo(f"Unknown task type {task['type']}")
|
||||
# rerun with new task slected from above cases
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
@@ -27,13 +27,6 @@ describe("signin flow", () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
it("shows the logged in users email address if logged in with email", () => {
|
||||
const emailAddress = "user@example.com";
|
||||
cy.signInWithEmail(emailAddress);
|
||||
// The user will only see the email address if the window is wide enough, not technically required as even when hidden this will find it in the page.
|
||||
cy.viewport(1920, 1000);
|
||||
cy.contains('[data-cy="username"]', emailAddress);
|
||||
});
|
||||
});
|
||||
|
||||
export {};
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
module.exports = {
|
||||
i18n: {
|
||||
defaultLocale: "en",
|
||||
locales: ["en"],
|
||||
},
|
||||
};
|
||||
@@ -1,4 +1,6 @@
|
||||
/** @type {import('next').NextConfig} */
|
||||
const { i18n } = require("./next-i18next.config");
|
||||
|
||||
const nextConfig = {
|
||||
output: "standalone",
|
||||
reactStrictMode: true,
|
||||
@@ -16,6 +18,7 @@ const nextConfig = {
|
||||
*/
|
||||
// scrollRestoration: true,
|
||||
},
|
||||
i18n,
|
||||
};
|
||||
|
||||
module.exports = nextConfig;
|
||||
|
||||
Generated
+188
-13
@@ -34,6 +34,7 @@
|
||||
"install": "^0.13.0",
|
||||
"next": "13.0.6",
|
||||
"next-auth": "^4.18.6",
|
||||
"next-i18next": "^13.0.3",
|
||||
"nodemailer": "^6.8.0",
|
||||
"npm": "^9.2.0",
|
||||
"postcss-focus-visible": "^7.1.0",
|
||||
@@ -41,11 +42,13 @@
|
||||
"react-dom": "18.2.0",
|
||||
"react-feature-flags": "^1.0.0",
|
||||
"react-hook-form": "^7.42.1",
|
||||
"react-i18next": "^12.1.4",
|
||||
"react-icons": "^4.7.1",
|
||||
"react-table": "^7.8.0",
|
||||
"sharp": "^0.31.3",
|
||||
"swr": "^2.0.0",
|
||||
"tailwindcss": "^3.2.4",
|
||||
"unique-username-generator": "^1.1.3",
|
||||
"use-debounce": "^9.0.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
@@ -12685,6 +12688,15 @@
|
||||
"@types/unist": "*"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/hoist-non-react-statics": {
|
||||
"version": "3.3.1",
|
||||
"resolved": "https://registry.npmjs.org/@types/hoist-non-react-statics/-/hoist-non-react-statics-3.3.1.tgz",
|
||||
"integrity": "sha512-iMIqiko6ooLrTh1joXodJK5X9xeEALT1kM5G3ZLhD3hszxBdIEd5C75U834D9mLcINgD4OyZf5uQXjkuYydWvA==",
|
||||
"dependencies": {
|
||||
"@types/react": "*",
|
||||
"hoist-non-react-statics": "^3.3.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/html-minifier-terser": {
|
||||
"version": "6.1.0",
|
||||
"resolved": "https://registry.npmjs.org/@types/html-minifier-terser/-/html-minifier-terser-6.1.0.tgz",
|
||||
@@ -12891,8 +12903,7 @@
|
||||
"node_modules/@types/prop-types": {
|
||||
"version": "15.7.5",
|
||||
"resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.5.tgz",
|
||||
"integrity": "sha512-JCB8C6SnDoQf0cNycqd/35A7MjcnK+ZTqE7judS6o7utxUCg6imJg3QK2qzHKszlTjcj2cn+NwMB2i96ubpj7w==",
|
||||
"devOptional": true
|
||||
"integrity": "sha512-JCB8C6SnDoQf0cNycqd/35A7MjcnK+ZTqE7judS6o7utxUCg6imJg3QK2qzHKszlTjcj2cn+NwMB2i96ubpj7w=="
|
||||
},
|
||||
"node_modules/@types/qs": {
|
||||
"version": "6.9.7",
|
||||
@@ -12904,7 +12915,6 @@
|
||||
"version": "18.0.26",
|
||||
"resolved": "https://registry.npmjs.org/@types/react/-/react-18.0.26.tgz",
|
||||
"integrity": "sha512-hCR3PJQsAIXyxhTNSiDFY//LhnMZWpNNr5etoCqx/iUfGc5gXWtQR2Phl908jVR6uPXacojQWTg4qRpkxTuGug==",
|
||||
"devOptional": true,
|
||||
"dependencies": {
|
||||
"@types/prop-types": "*",
|
||||
"@types/scheduler": "*",
|
||||
@@ -12923,8 +12933,7 @@
|
||||
"node_modules/@types/scheduler": {
|
||||
"version": "0.16.2",
|
||||
"resolved": "https://registry.npmjs.org/@types/scheduler/-/scheduler-0.16.2.tgz",
|
||||
"integrity": "sha512-hppQEBDmlwhFAXKJX2KnWLYu5yMfi91yazPb2l+lbJiwW+wdo1gNeRA+3RgNSO39WYX2euey41KEwnqesU2Jew==",
|
||||
"devOptional": true
|
||||
"integrity": "sha512-hppQEBDmlwhFAXKJX2KnWLYu5yMfi91yazPb2l+lbJiwW+wdo1gNeRA+3RgNSO39WYX2euey41KEwnqesU2Jew=="
|
||||
},
|
||||
"node_modules/@types/semver": {
|
||||
"version": "7.3.13",
|
||||
@@ -16582,7 +16591,6 @@
|
||||
"version": "3.27.1",
|
||||
"resolved": "https://registry.npmjs.org/core-js/-/core-js-3.27.1.tgz",
|
||||
"integrity": "sha512-GutwJLBChfGCpwwhbYoqfv03LAfmiz7e7D/BNxzeMxwQf10GRSzqiOjx7AmtEk+heiD/JWmBuyBPgFtx0Sg1ww==",
|
||||
"dev": true,
|
||||
"hasInstallScript": true,
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
@@ -21336,6 +21344,14 @@
|
||||
"node": ">= 12"
|
||||
}
|
||||
},
|
||||
"node_modules/html-parse-stringify": {
|
||||
"version": "3.0.1",
|
||||
"resolved": "https://registry.npmjs.org/html-parse-stringify/-/html-parse-stringify-3.0.1.tgz",
|
||||
"integrity": "sha512-KknJ50kTInJ7qIScF3jeaFRpMpE8/lfiTdzf/twXyPBLAGrLRTmkz3AdTnKeh40X8k9L2fdYwEp/42WGXIRGcg==",
|
||||
"dependencies": {
|
||||
"void-elements": "3.1.0"
|
||||
}
|
||||
},
|
||||
"node_modules/html-tags": {
|
||||
"version": "3.2.0",
|
||||
"resolved": "https://registry.npmjs.org/html-tags/-/html-tags-3.2.0.tgz",
|
||||
@@ -21478,6 +21494,34 @@
|
||||
"integrity": "sha512-iimHkHPfIAQ8zCDQLgn08pRqSVioyWvnGfaQ8gond2wf7Jq2jJ+24ykmnRyiz3fIldcn4oUuQXpjqKLhSVR7lw==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/i18next": {
|
||||
"version": "22.4.9",
|
||||
"resolved": "https://registry.npmjs.org/i18next/-/i18next-22.4.9.tgz",
|
||||
"integrity": "sha512-8gWMmUz460KJDQp/ob3MNUX84cVuDRY9PLFPnV8d+Qezz/6dkjxwOaH70xjrCNDO+JrUL25iXfAIN9wUkInNZw==",
|
||||
"funding": [
|
||||
{
|
||||
"type": "individual",
|
||||
"url": "https://locize.com"
|
||||
},
|
||||
{
|
||||
"type": "individual",
|
||||
"url": "https://locize.com/i18next.html"
|
||||
},
|
||||
{
|
||||
"type": "individual",
|
||||
"url": "https://www.i18next.com/how-to/faq#i18next-is-awesome.-how-can-i-support-the-project"
|
||||
}
|
||||
],
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.20.6"
|
||||
}
|
||||
},
|
||||
"node_modules/i18next-fs-backend": {
|
||||
"version": "2.1.1",
|
||||
"resolved": "https://registry.npmjs.org/i18next-fs-backend/-/i18next-fs-backend-2.1.1.tgz",
|
||||
"integrity": "sha512-FTnj+UmNgT3YRml5ruRv0jMZDG7odOL/OP5PF5mOqvXud2vHrPOOs68Zdk6iqzL47cnnM0ZVkK2BAvpFeDJToA=="
|
||||
},
|
||||
"node_modules/iconv-lite": {
|
||||
"version": "0.4.24",
|
||||
"resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.4.24.tgz",
|
||||
@@ -27555,6 +27599,45 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/next-i18next": {
|
||||
"version": "13.0.3",
|
||||
"resolved": "https://registry.npmjs.org/next-i18next/-/next-i18next-13.0.3.tgz",
|
||||
"integrity": "sha512-7AA8J6WbkxRBtSf1+97LSAE7btxWZHsBIJEJ3FuTSBgYtpRiO5NGjcb8XbNAlz6yGU0TtS+yZE+/Wu83KhIT1Q==",
|
||||
"funding": [
|
||||
{
|
||||
"type": "individual",
|
||||
"url": "https://locize.com/i18next.html"
|
||||
},
|
||||
{
|
||||
"type": "individual",
|
||||
"url": "https://www.i18next.com/how-to/faq#i18next-is-awesome.-how-can-i-support-the-project"
|
||||
},
|
||||
{
|
||||
"type": "individual",
|
||||
"url": "https://locize.com"
|
||||
},
|
||||
{
|
||||
"type": "individual",
|
||||
"url": "https://github.com/belgattitude"
|
||||
}
|
||||
],
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.20.6",
|
||||
"@types/hoist-non-react-statics": "^3.3.1",
|
||||
"core-js": "^3",
|
||||
"hoist-non-react-statics": "^3.3.2",
|
||||
"i18next-fs-backend": "^2.1.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=14"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"i18next": "^22.0.6",
|
||||
"next": ">= 12.0.0",
|
||||
"react": ">= 17.0.2",
|
||||
"react-i18next": "^12.1.1"
|
||||
}
|
||||
},
|
||||
"node_modules/next/node_modules/postcss": {
|
||||
"version": "8.4.14",
|
||||
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.14.tgz",
|
||||
@@ -32529,6 +32612,27 @@
|
||||
"react": "^16.8.0 || ^17 || ^18"
|
||||
}
|
||||
},
|
||||
"node_modules/react-i18next": {
|
||||
"version": "12.1.4",
|
||||
"resolved": "https://registry.npmjs.org/react-i18next/-/react-i18next-12.1.4.tgz",
|
||||
"integrity": "sha512-XQND7jYtgM7ht5PH3yIZljCRpAMTlH/zmngM9ZjToqa+0BR6xuu8c7QF0WIIOEjcMTB2S3iOfpN/xG/ZrAnO6g==",
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.20.6",
|
||||
"html-parse-stringify": "^3.0.1"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"i18next": ">= 19.0.0",
|
||||
"react": ">= 16.8.0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"react-dom": {
|
||||
"optional": true
|
||||
},
|
||||
"react-native": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/react-icons": {
|
||||
"version": "4.7.1",
|
||||
"resolved": "https://registry.npmjs.org/react-icons/-/react-icons-4.7.1.tgz",
|
||||
@@ -36019,6 +36123,11 @@
|
||||
"imurmurhash": "^0.1.4"
|
||||
}
|
||||
},
|
||||
"node_modules/unique-username-generator": {
|
||||
"version": "1.1.3",
|
||||
"resolved": "https://registry.npmjs.org/unique-username-generator/-/unique-username-generator-1.1.3.tgz",
|
||||
"integrity": "sha512-TB6YdqPMKMpTSgxAzjZkKWtmpZPHvARoWreCKBpc1UrLFz/0C6Q96/qdjpLr9OXPCHk16sD1LHjTr3JDj7q2JA=="
|
||||
},
|
||||
"node_modules/unist-builder": {
|
||||
"version": "2.0.3",
|
||||
"resolved": "https://registry.npmjs.org/unist-builder/-/unist-builder-2.0.3.tgz",
|
||||
@@ -36539,6 +36648,14 @@
|
||||
"integrity": "sha512-2ham8XPWTONajOR0ohOKOHXkm3+gaBmGut3SRuu75xLd/RRaY6vqgh8NBYYk7+RW3u5AtzPQZG8F10LHkl0lAQ==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/void-elements": {
|
||||
"version": "3.1.0",
|
||||
"resolved": "https://registry.npmjs.org/void-elements/-/void-elements-3.1.0.tgz",
|
||||
"integrity": "sha512-Dhxzh5HZuiHQhbvTW9AMetFfBHDMYpo23Uo9btPXgdYP+3T5S+p+jgNy7spra+veYhBP2dCSgxR/i2Y02h5/6w==",
|
||||
"engines": {
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/w3c-xmlserializer": {
|
||||
"version": "4.0.0",
|
||||
"resolved": "https://registry.npmjs.org/w3c-xmlserializer/-/w3c-xmlserializer-4.0.0.tgz",
|
||||
@@ -46906,6 +47023,15 @@
|
||||
"@types/unist": "*"
|
||||
}
|
||||
},
|
||||
"@types/hoist-non-react-statics": {
|
||||
"version": "3.3.1",
|
||||
"resolved": "https://registry.npmjs.org/@types/hoist-non-react-statics/-/hoist-non-react-statics-3.3.1.tgz",
|
||||
"integrity": "sha512-iMIqiko6ooLrTh1joXodJK5X9xeEALT1kM5G3ZLhD3hszxBdIEd5C75U834D9mLcINgD4OyZf5uQXjkuYydWvA==",
|
||||
"requires": {
|
||||
"@types/react": "*",
|
||||
"hoist-non-react-statics": "^3.3.0"
|
||||
}
|
||||
},
|
||||
"@types/html-minifier-terser": {
|
||||
"version": "6.1.0",
|
||||
"resolved": "https://registry.npmjs.org/@types/html-minifier-terser/-/html-minifier-terser-6.1.0.tgz",
|
||||
@@ -47098,8 +47224,7 @@
|
||||
"@types/prop-types": {
|
||||
"version": "15.7.5",
|
||||
"resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.5.tgz",
|
||||
"integrity": "sha512-JCB8C6SnDoQf0cNycqd/35A7MjcnK+ZTqE7judS6o7utxUCg6imJg3QK2qzHKszlTjcj2cn+NwMB2i96ubpj7w==",
|
||||
"devOptional": true
|
||||
"integrity": "sha512-JCB8C6SnDoQf0cNycqd/35A7MjcnK+ZTqE7judS6o7utxUCg6imJg3QK2qzHKszlTjcj2cn+NwMB2i96ubpj7w=="
|
||||
},
|
||||
"@types/qs": {
|
||||
"version": "6.9.7",
|
||||
@@ -47111,7 +47236,6 @@
|
||||
"version": "18.0.26",
|
||||
"resolved": "https://registry.npmjs.org/@types/react/-/react-18.0.26.tgz",
|
||||
"integrity": "sha512-hCR3PJQsAIXyxhTNSiDFY//LhnMZWpNNr5etoCqx/iUfGc5gXWtQR2Phl908jVR6uPXacojQWTg4qRpkxTuGug==",
|
||||
"devOptional": true,
|
||||
"requires": {
|
||||
"@types/prop-types": "*",
|
||||
"@types/scheduler": "*",
|
||||
@@ -47130,8 +47254,7 @@
|
||||
"@types/scheduler": {
|
||||
"version": "0.16.2",
|
||||
"resolved": "https://registry.npmjs.org/@types/scheduler/-/scheduler-0.16.2.tgz",
|
||||
"integrity": "sha512-hppQEBDmlwhFAXKJX2KnWLYu5yMfi91yazPb2l+lbJiwW+wdo1gNeRA+3RgNSO39WYX2euey41KEwnqesU2Jew==",
|
||||
"devOptional": true
|
||||
"integrity": "sha512-hppQEBDmlwhFAXKJX2KnWLYu5yMfi91yazPb2l+lbJiwW+wdo1gNeRA+3RgNSO39WYX2euey41KEwnqesU2Jew=="
|
||||
},
|
||||
"@types/semver": {
|
||||
"version": "7.3.13",
|
||||
@@ -50004,8 +50127,7 @@
|
||||
"core-js": {
|
||||
"version": "3.27.1",
|
||||
"resolved": "https://registry.npmjs.org/core-js/-/core-js-3.27.1.tgz",
|
||||
"integrity": "sha512-GutwJLBChfGCpwwhbYoqfv03LAfmiz7e7D/BNxzeMxwQf10GRSzqiOjx7AmtEk+heiD/JWmBuyBPgFtx0Sg1ww==",
|
||||
"dev": true
|
||||
"integrity": "sha512-GutwJLBChfGCpwwhbYoqfv03LAfmiz7e7D/BNxzeMxwQf10GRSzqiOjx7AmtEk+heiD/JWmBuyBPgFtx0Sg1ww=="
|
||||
},
|
||||
"core-js-compat": {
|
||||
"version": "3.27.1",
|
||||
@@ -53725,6 +53847,14 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"html-parse-stringify": {
|
||||
"version": "3.0.1",
|
||||
"resolved": "https://registry.npmjs.org/html-parse-stringify/-/html-parse-stringify-3.0.1.tgz",
|
||||
"integrity": "sha512-KknJ50kTInJ7qIScF3jeaFRpMpE8/lfiTdzf/twXyPBLAGrLRTmkz3AdTnKeh40X8k9L2fdYwEp/42WGXIRGcg==",
|
||||
"requires": {
|
||||
"void-elements": "3.1.0"
|
||||
}
|
||||
},
|
||||
"html-tags": {
|
||||
"version": "3.2.0",
|
||||
"resolved": "https://registry.npmjs.org/html-tags/-/html-tags-3.2.0.tgz",
|
||||
@@ -53825,6 +53955,20 @@
|
||||
"integrity": "sha512-iimHkHPfIAQ8zCDQLgn08pRqSVioyWvnGfaQ8gond2wf7Jq2jJ+24ykmnRyiz3fIldcn4oUuQXpjqKLhSVR7lw==",
|
||||
"dev": true
|
||||
},
|
||||
"i18next": {
|
||||
"version": "22.4.9",
|
||||
"resolved": "https://registry.npmjs.org/i18next/-/i18next-22.4.9.tgz",
|
||||
"integrity": "sha512-8gWMmUz460KJDQp/ob3MNUX84cVuDRY9PLFPnV8d+Qezz/6dkjxwOaH70xjrCNDO+JrUL25iXfAIN9wUkInNZw==",
|
||||
"peer": true,
|
||||
"requires": {
|
||||
"@babel/runtime": "^7.20.6"
|
||||
}
|
||||
},
|
||||
"i18next-fs-backend": {
|
||||
"version": "2.1.1",
|
||||
"resolved": "https://registry.npmjs.org/i18next-fs-backend/-/i18next-fs-backend-2.1.1.tgz",
|
||||
"integrity": "sha512-FTnj+UmNgT3YRml5ruRv0jMZDG7odOL/OP5PF5mOqvXud2vHrPOOs68Zdk6iqzL47cnnM0ZVkK2BAvpFeDJToA=="
|
||||
},
|
||||
"iconv-lite": {
|
||||
"version": "0.4.24",
|
||||
"resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.4.24.tgz",
|
||||
@@ -58442,6 +58586,18 @@
|
||||
"uuid": "^8.3.2"
|
||||
}
|
||||
},
|
||||
"next-i18next": {
|
||||
"version": "13.0.3",
|
||||
"resolved": "https://registry.npmjs.org/next-i18next/-/next-i18next-13.0.3.tgz",
|
||||
"integrity": "sha512-7AA8J6WbkxRBtSf1+97LSAE7btxWZHsBIJEJ3FuTSBgYtpRiO5NGjcb8XbNAlz6yGU0TtS+yZE+/Wu83KhIT1Q==",
|
||||
"requires": {
|
||||
"@babel/runtime": "^7.20.6",
|
||||
"@types/hoist-non-react-statics": "^3.3.1",
|
||||
"core-js": "^3",
|
||||
"hoist-non-react-statics": "^3.3.2",
|
||||
"i18next-fs-backend": "^2.1.0"
|
||||
}
|
||||
},
|
||||
"nice-try": {
|
||||
"version": "1.0.5",
|
||||
"resolved": "https://registry.npmjs.org/nice-try/-/nice-try-1.0.5.tgz",
|
||||
@@ -61933,6 +62089,15 @@
|
||||
"integrity": "sha512-2UIGqwMZksd5HS55crTT1ATLTr0rAI4jS7yVuqTaoRVDhY2Qc4IyjskCmpnmdYqUNOYFy04vW253tb2JRVh+IQ==",
|
||||
"requires": {}
|
||||
},
|
||||
"react-i18next": {
|
||||
"version": "12.1.4",
|
||||
"resolved": "https://registry.npmjs.org/react-i18next/-/react-i18next-12.1.4.tgz",
|
||||
"integrity": "sha512-XQND7jYtgM7ht5PH3yIZljCRpAMTlH/zmngM9ZjToqa+0BR6xuu8c7QF0WIIOEjcMTB2S3iOfpN/xG/ZrAnO6g==",
|
||||
"requires": {
|
||||
"@babel/runtime": "^7.20.6",
|
||||
"html-parse-stringify": "^3.0.1"
|
||||
}
|
||||
},
|
||||
"react-icons": {
|
||||
"version": "4.7.1",
|
||||
"resolved": "https://registry.npmjs.org/react-icons/-/react-icons-4.7.1.tgz",
|
||||
@@ -64641,6 +64806,11 @@
|
||||
"imurmurhash": "^0.1.4"
|
||||
}
|
||||
},
|
||||
"unique-username-generator": {
|
||||
"version": "1.1.3",
|
||||
"resolved": "https://registry.npmjs.org/unique-username-generator/-/unique-username-generator-1.1.3.tgz",
|
||||
"integrity": "sha512-TB6YdqPMKMpTSgxAzjZkKWtmpZPHvARoWreCKBpc1UrLFz/0C6Q96/qdjpLr9OXPCHk16sD1LHjTr3JDj7q2JA=="
|
||||
},
|
||||
"unist-builder": {
|
||||
"version": "2.0.3",
|
||||
"resolved": "https://registry.npmjs.org/unist-builder/-/unist-builder-2.0.3.tgz",
|
||||
@@ -65014,6 +65184,11 @@
|
||||
"integrity": "sha512-2ham8XPWTONajOR0ohOKOHXkm3+gaBmGut3SRuu75xLd/RRaY6vqgh8NBYYk7+RW3u5AtzPQZG8F10LHkl0lAQ==",
|
||||
"dev": true
|
||||
},
|
||||
"void-elements": {
|
||||
"version": "3.1.0",
|
||||
"resolved": "https://registry.npmjs.org/void-elements/-/void-elements-3.1.0.tgz",
|
||||
"integrity": "sha512-Dhxzh5HZuiHQhbvTW9AMetFfBHDMYpo23Uo9btPXgdYP+3T5S+p+jgNy7spra+veYhBP2dCSgxR/i2Y02h5/6w=="
|
||||
},
|
||||
"w3c-xmlserializer": {
|
||||
"version": "4.0.0",
|
||||
"resolved": "https://registry.npmjs.org/w3c-xmlserializer/-/w3c-xmlserializer-4.0.0.tgz",
|
||||
|
||||
@@ -51,6 +51,7 @@
|
||||
"install": "^0.13.0",
|
||||
"next": "13.0.6",
|
||||
"next-auth": "^4.18.6",
|
||||
"next-i18next": "^13.0.3",
|
||||
"nodemailer": "^6.8.0",
|
||||
"npm": "^9.2.0",
|
||||
"postcss-focus-visible": "^7.1.0",
|
||||
@@ -58,11 +59,13 @@
|
||||
"react-dom": "18.2.0",
|
||||
"react-feature-flags": "^1.0.0",
|
||||
"react-hook-form": "^7.42.1",
|
||||
"react-i18next": "^12.1.4",
|
||||
"react-icons": "^4.7.1",
|
||||
"react-table": "^7.8.0",
|
||||
"sharp": "^0.31.3",
|
||||
"swr": "^2.0.0",
|
||||
"tailwindcss": "^3.2.4",
|
||||
"unique-username-generator": "^1.1.3",
|
||||
"use-debounce": "^9.0.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"discord": "Discord",
|
||||
"github": "GitHub"
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"title": "Open Assistant",
|
||||
"subtitle": "Conversational AI for everyone.",
|
||||
"description": "Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world.",
|
||||
"blurb": "We believe we can create a revolution.",
|
||||
"blurb1": "In the same way that Stable Diffusion helped the world make art and images in new ways, we want to improve the world by providing amazing conversational AI.",
|
||||
"join_us_title": "Join us",
|
||||
"join_us_description": "All open source projects begin with people like you. Open source is the belief that if we collaborate we can together gift our knowledge and technology to the world for the benefit of humanity. Are you in? Find us here:",
|
||||
"faq_title": "Frequently Asked Questions",
|
||||
"faq_items": {
|
||||
"q0": "How far along is this project?",
|
||||
"a0": "We are in the early stages of development, working from established research in applying RLHF to large language models.",
|
||||
"q1": "Who is behind Open Assistant?",
|
||||
"a1": "Open Assistant is a project organized by LAION and individuals around the world interested in bringing this technology to everyone."
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
import { Box, useColorMode } from "@chakra-ui/react";
|
||||
import React, { useId } from "react";
|
||||
|
||||
export const AnimatedCircles = () => {
|
||||
const id = useId();
|
||||
const { colorMode } = useColorMode();
|
||||
const baseRingColor = colorMode === "light" ? "#d4d4d4" : "#005a69";
|
||||
const gradStopColor = colorMode === "light" ? "#06b6d4" : "#00f2ff";
|
||||
|
||||
return (
|
||||
<Box className="absolute left-1/2 top-4 h-[1026px] w-[1026px] -translate-x-1/3 stroke-gray-300/70 [mask-image:linear-gradient(to_bottom,white_20%,transparent_75%)] sm:top-16 sm:-translate-x-1/2 lg:-top-16 lg:ml-12 xl:-top-14 xl:ml-0">
|
||||
<svg
|
||||
viewBox="0 0 1026 1026"
|
||||
fill="none"
|
||||
aria-hidden="true"
|
||||
className="absolute inset-0 h-full w-full animate-spin-slow"
|
||||
>
|
||||
<path
|
||||
d="M1025 513c0 282.77-229.23 512-512 512S1 795.77 1 513 230.23 1 513 1s512 229.23 512 512Z"
|
||||
stroke={baseRingColor}
|
||||
strokeOpacity="0.7"
|
||||
/>
|
||||
<path d="M513 1025C230.23 1025 1 795.77 1 513" stroke={`url(#${id}-gradient-1)`} strokeLinecap="round" />
|
||||
<defs>
|
||||
<linearGradient id={`${id}-gradient-1`} x1="1" y1="513" x2="1" y2="1025" gradientUnits="userSpaceOnUse">
|
||||
<stop stopColor={gradStopColor} />
|
||||
<stop offset="1" stopColor={gradStopColor} stopOpacity="0" />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
</svg>
|
||||
<svg
|
||||
viewBox="0 0 1026 1026"
|
||||
fill="none"
|
||||
aria-hidden="true"
|
||||
className="absolute inset-0 h-full w-full animate-spin-reverse-slower"
|
||||
>
|
||||
<path
|
||||
d="M913 513c0 220.914-179.086 400-400 400S113 733.914 113 513s179.086-400 400-400 400 179.086 400 400Z"
|
||||
stroke={baseRingColor}
|
||||
strokeOpacity="0.7"
|
||||
/>
|
||||
<path d="M913 513c0 220.914-179.086 400-400 400" stroke={`url(#${id}-gradient-2)`} strokeLinecap="round" />
|
||||
<defs>
|
||||
<linearGradient id={`gradient-2`} x1="913" y1="513" x2="913" y2="913" gradientUnits="userSpaceOnUse">
|
||||
<stop stopColor={gradStopColor} />
|
||||
<stop offset="1" stopColor={gradStopColor} stopOpacity="0" />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
</svg>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1 @@
|
||||
export { AnimatedCircles } from "./AnimatedCircles";
|
||||
@@ -1,9 +1,14 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { Box, Link, Text, useColorMode } from "@chakra-ui/react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { useId } from "react";
|
||||
import { FaDiscord, FaGithub } from "react-icons/fa";
|
||||
|
||||
import { Container } from "./Container";
|
||||
|
||||
function CircleBackground({ width = 558, height = 558, ...props }) {
|
||||
const CIRCLE_HEIGHT = 558;
|
||||
const CIRCLE_WIDTH = 558;
|
||||
|
||||
function CircleBackground() {
|
||||
const id = useId();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
@@ -11,7 +16,14 @@ function CircleBackground({ width = 558, height = 558, ...props }) {
|
||||
const gradStopColor = colorMode === "light" ? "#fff" : "#000";
|
||||
|
||||
return (
|
||||
<svg viewBox="0 0 558 558" width={width} height={height} fill="none" aria-hidden="true" {...props}>
|
||||
<svg
|
||||
viewBox={`0 0 ${CIRCLE_HEIGHT} ${CIRCLE_WIDTH}`}
|
||||
width={CIRCLE_WIDTH}
|
||||
height={CIRCLE_HEIGHT}
|
||||
fill="none"
|
||||
aria-hidden="true"
|
||||
className="animate-spin-slower"
|
||||
>
|
||||
<defs>
|
||||
<linearGradient id={id} x1="79" y1="16" x2="105" y2="237" gradientUnits="userSpaceOnUse">
|
||||
<stop stopColor={gradStopColor} />
|
||||
@@ -30,66 +42,54 @@ function CircleBackground({ width = 558, height = 558, ...props }) {
|
||||
|
||||
export function CallToAction() {
|
||||
const { colorMode } = useColorMode();
|
||||
const { t } = useTranslation();
|
||||
const bgColorClass = colorMode === "light" ? "bg-gray-900" : "bg-gray-50";
|
||||
const headingColorClass = colorMode === "light" ? "text-white" : "text-black";
|
||||
const textColorClass = colorMode === "light" ? "text-gray-300" : "text-black";
|
||||
|
||||
return (
|
||||
<section id="join-us" className={`relative overflow-hidden py-20 sm:py-28 ${bgColorClass} ${textColorClass}`}>
|
||||
<div className="absolute top-1/2 left-20 -translate-y-1/2 sm:left-1/2 sm:-translate-x-1/2">
|
||||
<CircleBackground className="animate-spin-slower" />
|
||||
</div>
|
||||
<Box
|
||||
as="section"
|
||||
id="join-us"
|
||||
className={`relative overflow-hidden py-20 sm:py-28 ${bgColorClass} ${textColorClass}`}
|
||||
>
|
||||
<Box className="absolute top-1/2 left-20 -translate-y-1/2 sm:left-1/2 sm:-translate-x-1/2">
|
||||
<CircleBackground />
|
||||
</Box>
|
||||
<Container className="relative">
|
||||
<div className="mx-auto max-w-md sm:text-center">
|
||||
<h2 className={`text-3xl font-medium tracking-tight sm:text-4xl ${headingColorClass}`}>Join Us</h2>
|
||||
<p className="mt-4 text-lg">
|
||||
All open source projects begin with people like you. Open source is the belief that if we collaborate we can
|
||||
together gift our knowledge and technology to the world for the benefit of humanity. Are you in? Find us
|
||||
here:
|
||||
</p>
|
||||
<div className="mt-8 flex justify-center">
|
||||
<a href="https://ykilcher.com/open-assistant-discord" rel="noreferrer" target="_blank">
|
||||
<Box className="mx-auto max-w-md sm:text-center">
|
||||
<Text as="h2" className={`text-3xl font-medium tracking-tight sm:text-4xl ${headingColorClass}`}>
|
||||
{t("index:join_us_title")}
|
||||
</Text>
|
||||
<Text as="p" className="mt-4 text-lg">
|
||||
{t("index:join_us_description")}
|
||||
</Text>
|
||||
<Box className="mt-8 flex justify-center">
|
||||
<Link href="https://ykilcher.com/open-assistant-discord" rel="noreferrer" target="_blank">
|
||||
<button
|
||||
type="button"
|
||||
className="mb-2 ml-6 flex items-center rounded-md border border-transparent bg-blue-600 px-6 py-3 text-base font-medium text-white shadow-sm hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2"
|
||||
>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 640 512" className="w-6 h-6">
|
||||
<path
|
||||
fill="currentColor"
|
||||
d="M524.531,69.836a1.5,1.5,0,0,0-.764-.7A485.065,485.065,0,0,0,404.081,32.03a1.816,1.816,0,0,0-1.923.91,337.461,337.461,0,0,0-14.9,30.6,447.848,447.848,0,0,0-134.426,0,309.541,309.541,0,0,0-15.135-30.6,1.89,1.89,0,0,0-1.924-.91A483.689,483.689,0,0,0,116.085,69.137a1.712,1.712,0,0,0-.788.676C39.068,183.651,18.186,294.69,28.43,404.354a2.016,2.016,0,0,0,.765,1.375A487.666,487.666,0,0,0,176.02,479.918a1.9,1.9,0,0,0,2.063-.676A348.2,348.2,0,0,0,208.12,430.4a1.86,1.86,0,0,0-1.019-2.588,321.173,321.173,0,0,1-45.868-21.853,1.885,1.885,0,0,1-.185-3.126c3.082-2.309,6.166-4.711,9.109-7.137a1.819,1.819,0,0,1,1.9-.256c96.229,43.917,200.41,43.917,295.5,0a1.812,1.812,0,0,1,1.924.233c2.944,2.426,6.027,4.851,9.132,7.16a1.884,1.884,0,0,1-.162,3.126,301.407,301.407,0,0,1-45.89,21.83,1.875,1.875,0,0,0-1,2.611,391.055,391.055,0,0,0,30.014,48.815,1.864,1.864,0,0,0,2.063.7A486.048,486.048,0,0,0,610.7,405.729a1.882,1.882,0,0,0,.765-1.352C623.729,277.594,590.933,167.465,524.531,69.836ZM222.491,337.58c-28.972,0-52.844-26.587-52.844-59.239S193.056,219.1,222.491,219.1c29.665,0,53.306,26.82,52.843,59.239C275.334,310.993,251.924,337.58,222.491,337.58Zm195.38,0c-28.971,0-52.843-26.587-52.843-59.239S388.437,219.1,417.871,219.1c29.667,0,53.307,26.82,52.844,59.239C470.715,310.993,447.538,337.58,417.871,337.58Z"
|
||||
/>
|
||||
</svg>
|
||||
<span className="text-lg ml-3">Discord</span>
|
||||
<FaDiscord size={25} />
|
||||
<Text as="span" className="text-lg ml-3">
|
||||
{t("discord")}
|
||||
</Text>
|
||||
</button>
|
||||
</a>
|
||||
|
||||
<a href="https://github.com/LAION-AI/Open-Assistant" rel="noreferrer" target="_blank">
|
||||
</Link>
|
||||
<Link href="https://github.com/LAION-AI/Open-Assistant" rel="noreferrer" target="_blank">
|
||||
<button
|
||||
type="button"
|
||||
className="mb-2 ml-6 flex items-center rounded-md border border-transparent bg-blue-600 px-6 py-3 text-base font-medium text-white shadow-sm hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2"
|
||||
>
|
||||
<svg
|
||||
className="mr-2 -ml-1 w-6 h-6"
|
||||
aria-hidden="true"
|
||||
focusable="false"
|
||||
data-prefix="fab"
|
||||
data-icon="github"
|
||||
role="img"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 496 512"
|
||||
>
|
||||
<path
|
||||
fill="currentColor"
|
||||
d="M165.9 397.4c0 2-2.3 3.6-5.2 3.6-3.3 .3-5.6-1.3-5.6-3.6 0-2 2.3-3.6 5.2-3.6 3-.3 5.6 1.3 5.6 3.6zm-31.1-4.5c-.7 2 1.3 4.3 4.3 4.9 2.6 1 5.6 0 6.2-2s-1.3-4.3-4.3-5.2c-2.6-.7-5.5 .3-6.2 2.3zm44.2-1.7c-2.9 .7-4.9 2.6-4.6 4.9 .3 2 2.9 3.3 5.9 2.6 2.9-.7 4.9-2.6 4.6-4.6-.3-1.9-3-3.2-5.9-2.9zM244.8 8C106.1 8 0 113.3 0 252c0 110.9 69.8 205.8 169.5 239.2 12.8 2.3 17.3-5.6 17.3-12.1 0-6.2-.3-40.4-.3-61.4 0 0-70 15-84.7-29.8 0 0-11.4-29.1-27.8-36.6 0 0-22.9-15.7 1.6-15.4 0 0 24.9 2 38.6 25.8 21.9 38.6 58.6 27.5 72.9 20.9 2.3-16 8.8-27.1 16-33.7-55.9-6.2-112.3-14.3-112.3-110.5 0-27.5 7.6-41.3 23.6-58.9-2.6-6.5-11.1-33.3 2.6-67.9 20.9-6.5 69 27 69 27 20-5.6 41.5-8.5 62.8-8.5s42.8 2.9 62.8 8.5c0 0 48.1-33.6 69-27 13.7 34.7 5.2 61.4 2.6 67.9 16 17.7 25.8 31.5 25.8 58.9 0 96.5-58.9 104.2-114.8 110.5 9.2 7.9 17 22.9 17 46.4 0 33.7-.3 75.4-.3 83.6 0 6.5 4.6 14.4 17.3 12.1C428.2 457.8 496 362.9 496 252 496 113.3 383.5 8 244.8 8zM97.2 352.9c-1.3 1-1 3.3 .7 5.2 1.6 1.6 3.9 2.3 5.2 1 1.3-1 1-3.3-.7-5.2-1.6-1.6-3.9-2.3-5.2-1zm-10.8-8.1c-.7 1.3 .3 2.9 2.3 3.9 1.6 1 3.6 .7 4.3-.7 .7-1.3-.3-2.9-2.3-3.9-2-.6-3.6-.3-4.3 .7zm32.4 35.6c-1.6 1.3-1 4.3 1.3 6.2 2.3 2.3 5.2 2.6 6.5 1 1.3-1.3 .7-4.3-1.3-6.2-2.2-2.3-5.2-2.6-6.5-1zm-11.4-14.7c-1.6 1-1.6 3.6 0 5.9 1.6 2.3 4.3 3.3 5.6 2.3 1.6-1.3 1.6-3.9 0-6.2-1.4-2.3-4-3.3-5.6-2z"
|
||||
></path>
|
||||
</svg>
|
||||
|
||||
<span className="text-lg ml-1">Github</span>
|
||||
<FaGithub size={25} />
|
||||
<Text as="span" className="text-lg ml-3">
|
||||
{t("github")}
|
||||
</Text>
|
||||
</button>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</Link>
|
||||
</Box>
|
||||
</Box>
|
||||
</Container>
|
||||
</section>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ export function SlimFooter() {
|
||||
<FooterLink href="/terms-of-service" label="Terms of Service" />
|
||||
<FooterLink href="https://github.com/LAION-AI/Open-Assistant" label="Github" />
|
||||
<FooterLink href="https://ykilcher.com/open-assistant-discord" label="Discord" />
|
||||
<FooterLink href="https://projects.laion.ai/Open-Assistant/" label="Docs" />
|
||||
</Box>
|
||||
</nav>
|
||||
</Box>
|
||||
|
||||
@@ -28,7 +28,3 @@ export const EmptyState = (props: EmptyStateProps) => {
|
||||
export const TaskEmptyState = () => {
|
||||
return <EmptyState text="Looks like no tasks were found." icon={FiAlertTriangle} />;
|
||||
};
|
||||
|
||||
export const PageEmptyState = () => {
|
||||
return <EmptyState text="Sorry, the page you are looking for does not exist." icon={FiAlertTriangle} />;
|
||||
};
|
||||
|
||||
@@ -1,73 +1,42 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { Box, List, ListItem, Text, useColorMode } from "@chakra-ui/react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
|
||||
import { Container } from "./Container";
|
||||
|
||||
const faqs = [
|
||||
[
|
||||
{
|
||||
question: "How far along is this project?",
|
||||
answer:
|
||||
"We are in the early stages of development, working from established research in applying RLHF to large language models.",
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
question: "Who is behind Open Assistant?",
|
||||
answer:
|
||||
"Open Assistant is a project organized by LAION and individuals around the world interested in bringing this technology to everyone.",
|
||||
},
|
||||
],
|
||||
[
|
||||
// {
|
||||
// question: 'Where can I learn more?',
|
||||
// answer:
|
||||
// 'Please feel free to reach out to us on Discord. We are happy to answer any questions you may have.',
|
||||
// },
|
||||
],
|
||||
];
|
||||
const FAQS = Array.from({ length: 2 });
|
||||
|
||||
export function Faq() {
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const { t } = useTranslation("index");
|
||||
const headingColorClass = colorMode === "light" ? "text-gray-900" : "text-white";
|
||||
const textColorClass = colorMode === "light" ? "text-gray-700" : "text-gray-100";
|
||||
|
||||
return (
|
||||
<section id="faq" aria-labelledby="faqs-title" className="border-t border-gray-200 py-20 sm:py-32">
|
||||
<Box as="section" id="faq" aria-labelledby="faqs-title" className="border-t border-gray-200 py-20 sm:py-32">
|
||||
<Container className="">
|
||||
<div className="mx-auto max-w-2xl lg:mx-0">
|
||||
<h2 id="faqs-title" className={`text-3xl font-medium tracking-tight ${headingColorClass}`}>
|
||||
Frequently Asked Questions
|
||||
</h2>
|
||||
{/* <p className="mt-2 text-lg text-gray-600">
|
||||
If you have anything else you want to ask,{' '}
|
||||
<Link
|
||||
href="mailto:info@open-assistant.tech"
|
||||
className="text-gray-900 underline"
|
||||
>
|
||||
reach out to us
|
||||
</Link>
|
||||
.
|
||||
</p> */}
|
||||
</div>
|
||||
<ul
|
||||
<Box className="mx-auto max-w-2xl lg:mx-0">
|
||||
<Text as="h2" id="faqs-title" className={`text-3xl font-medium tracking-tight ${headingColorClass}`}>
|
||||
{t("faq_title")}
|
||||
</Text>
|
||||
</Box>
|
||||
<List
|
||||
role="list"
|
||||
className="mx-auto mt-16 grid max-w-2xl grid-cols-1 gap-8 sm:mt-20 lg:max-w-none lg:grid-cols-3"
|
||||
>
|
||||
{faqs.map((column, columnIndex) => (
|
||||
<li key={columnIndex}>
|
||||
<ul role="list" className="space-y-10">
|
||||
{column.map((faq, faqIndex) => (
|
||||
<li key={faqIndex}>
|
||||
<h3 className={`text-lg font-semibold leading-6 ${headingColorClass}`}>{faq.question}</h3>
|
||||
<p className={`mt-4 text-sm ${textColorClass}`}>{faq.answer}</p>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
{FAQS.map((_, index) => {
|
||||
return (
|
||||
<ListItem className="space-y-10" key={`question_${index}`}>
|
||||
<Text as="h3" className={`text-lg font-semibold leading-6 ${headingColorClass}`}>
|
||||
{t(`faq_items.q${index}`)}
|
||||
</Text>
|
||||
<Text as="p" className={`mt-4 text-sm ${textColorClass}`}>
|
||||
{t(`faq_items.a${index}`)}
|
||||
</Text>
|
||||
</ListItem>
|
||||
);
|
||||
})}
|
||||
</List>
|
||||
</Container>
|
||||
</section>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ export function Footer() {
|
||||
</Flex>
|
||||
|
||||
<nav>
|
||||
<Box display="flex" flexDirection={["column", "row"]} gap={["6", "14"]} alignItems="center" fontSize="sm">
|
||||
<Box display="flex" flexDirection={["column", "row"]} gap={["6", "14"]} fontSize="sm">
|
||||
<Flex direction="column" alignItems={["center", "start"]}>
|
||||
<Text fontWeight="bold" color={textColor}>
|
||||
Legal
|
||||
@@ -57,6 +57,12 @@ export function Footer() {
|
||||
<FooterLink href="https://github.com/LAION-AI/Open-Assistant" label="Github" />
|
||||
<FooterLink href="https://ykilcher.com/open-assistant-discord" label="Discord" />
|
||||
</Flex>
|
||||
<Flex direction="column" alignItems={["center", "start"]}>
|
||||
<Text fontWeight="bold" color={textColor}>
|
||||
About
|
||||
</Text>
|
||||
<FooterLink href="https://projects.laion.ai/Open-Assistant" label="Docs" />
|
||||
</Flex>
|
||||
</Box>
|
||||
</nav>
|
||||
</Box>
|
||||
|
||||
@@ -74,7 +74,7 @@ export function UserMenu() {
|
||||
<Box display="flex" alignItems="center" gap="3" p="1" paddingRight={[1, 1, 1, 6, 6]}>
|
||||
<Avatar size="sm" bgImage={session.user.image}></Avatar>
|
||||
<Text data-cy="username" className="hidden lg:flex">
|
||||
{session.user.name || session.user.email}
|
||||
{session.user.name || "New User"}
|
||||
</Text>
|
||||
</Box>
|
||||
</MenuButton>
|
||||
|
||||
@@ -1,87 +1,36 @@
|
||||
import { useColorMode } from "@chakra-ui/react";
|
||||
import { Box, Text, useColorMode } from "@chakra-ui/react";
|
||||
import Image from "next/image";
|
||||
import { useId } from "react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
|
||||
import { Container } from "./Container";
|
||||
|
||||
function BackgroundIllustration(props) {
|
||||
const id = useId();
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const baseRingColor = colorMode === "light" ? "#d4d4d4" : "#005a69";
|
||||
const gradStopColor = colorMode === "light" ? "#06b6d4" : "#00f2ff";
|
||||
|
||||
return (
|
||||
<div {...props}>
|
||||
<svg
|
||||
viewBox="0 0 1026 1026"
|
||||
fill="none"
|
||||
aria-hidden="true"
|
||||
className="absolute inset-0 h-full w-full animate-spin-slow"
|
||||
>
|
||||
<path
|
||||
d="M1025 513c0 282.77-229.23 512-512 512S1 795.77 1 513 230.23 1 513 1s512 229.23 512 512Z"
|
||||
stroke={baseRingColor}
|
||||
strokeOpacity="0.7"
|
||||
/>
|
||||
<path d="M513 1025C230.23 1025 1 795.77 1 513" stroke={`url(#${id}-gradient-1)`} strokeLinecap="round" />
|
||||
<defs>
|
||||
<linearGradient id={`${id}-gradient-1`} x1="1" y1="513" x2="1" y2="1025" gradientUnits="userSpaceOnUse">
|
||||
<stop stopColor={gradStopColor} />
|
||||
<stop offset="1" stopColor={gradStopColor} stopOpacity="0" />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
</svg>
|
||||
<svg
|
||||
viewBox="0 0 1026 1026"
|
||||
fill="none"
|
||||
aria-hidden="true"
|
||||
className="absolute inset-0 h-full w-full animate-spin-reverse-slower"
|
||||
>
|
||||
<path
|
||||
d="M913 513c0 220.914-179.086 400-400 400S113 733.914 113 513s179.086-400 400-400 400 179.086 400 400Z"
|
||||
stroke={baseRingColor}
|
||||
strokeOpacity="0.7"
|
||||
/>
|
||||
<path d="M913 513c0 220.914-179.086 400-400 400" stroke={`url(#${id}-gradient-2)`} strokeLinecap="round" />
|
||||
<defs>
|
||||
<linearGradient id={`${id}-gradient-2`} x1="913" y1="513" x2="913" y2="913" gradientUnits="userSpaceOnUse">
|
||||
<stop stopColor={gradStopColor} />
|
||||
<stop offset="1" stopColor={gradStopColor} stopOpacity="0" />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
</svg>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
import { AnimatedCircles } from "./AnimatedCircles";
|
||||
|
||||
export function Hero() {
|
||||
const { t } = useTranslation("index");
|
||||
const { colorMode } = useColorMode();
|
||||
const pTextColor = colorMode === "light" ? "text-gray-600" : "text-white";
|
||||
const fancyTextGradientClasses =
|
||||
colorMode === "light" ? "from-blue-600 via-sky-400 to-blue-700" : "from-blue-500 via-sky-300 to-blue-400";
|
||||
|
||||
return (
|
||||
<div className="overflow-hidden py-20 sm:py-32 lg:pb-32 xl:pb-36">
|
||||
<Box className="overflow-hidden py-20 sm:py-32 lg:pb-32 xl:pb-36">
|
||||
<Container className="">
|
||||
<div className="lg:grid lg:grid-cols-12 lg:gap-x-8 lg:gap-y-20">
|
||||
<div className="relative mx-auto max-w-2xl lg:col-span-7 lg:max-w-none lg:pt-6 xl:col-span-6">
|
||||
<h1 className="text-5xl mb-6 font-bold tracking-tight">Open Assistant</h1>
|
||||
<p
|
||||
className={`bg-gradient-to-r ${fancyTextGradientClasses} mt-8 text-3xl inline bg-clip-text font-display tracking-tight text-transparent`}
|
||||
<Box className="lg:grid lg:grid-cols-12 lg:gap-x-8 lg:gap-y-20">
|
||||
<Box className="relative mx-auto max-w-2xl lg:col-span-7 lg:max-w-none lg:pt-6 xl:col-span-6">
|
||||
<Text as="h1" className="text-5xl mb-6 font-bold tracking-tight">
|
||||
{t("title")}
|
||||
</Text>
|
||||
<Text
|
||||
as="h2"
|
||||
className={`bg-gradient-to-r ${fancyTextGradientClasses} font-bold mt-8 text-3xl inline bg-clip-text font-display tracking-tight text-transparent`}
|
||||
>
|
||||
<b>Conversational AI for everyone.</b>
|
||||
</p>
|
||||
<p className={`mt-6 text-lg ${pTextColor}`}>We believe we can create a revolution.</p>
|
||||
<p className={`mt-6 text-lg ${pTextColor}`}>
|
||||
In the same way that Stable Diffusion helped the world make art and images in new ways, we want to improve
|
||||
the world by providing amazing conversational AI.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="relative mt-10 sm:mt-20 lg:col-span-5 lg:row-span-2 lg:mt-0 xl:col-span-6">
|
||||
<BackgroundIllustration className="absolute left-1/2 top-4 h-[1026px] w-[1026px] -translate-x-1/3 stroke-gray-300/70 [mask-image:linear-gradient(to_bottom,white_20%,transparent_75%)] sm:top-16 sm:-translate-x-1/2 lg:-top-16 lg:ml-12 xl:-top-14 xl:ml-0" />
|
||||
<div className="-mx-4 h-[448px] px-9 [mask-image:linear-gradient(to_bottom,white_60%,transparent)] sm:mx-0 lg:absolute lg:-inset-x-10 lg:-top-10 lg:-bottom-20 lg:h-auto lg:px-0 lg:pt-10 xl:-bottom-32">
|
||||
{t("subtitle")}
|
||||
</Text>
|
||||
<Text className={`mt-6 text-lg ${pTextColor}`}>{t("blurb")}</Text>
|
||||
<Text className={`mt-6 text-lg ${pTextColor}`}>{t("blurb1")}</Text>
|
||||
</Box>
|
||||
<Box className="relative mt-10 sm:mt-20 lg:col-span-5 lg:row-span-2 lg:mt-0 xl:col-span-6">
|
||||
<AnimatedCircles />
|
||||
<Box className="-mx-4 h-[448px] px-9 [mask-image:linear-gradient(to_bottom,white_60%,transparent)] sm:mx-0 lg:absolute lg:-inset-x-10 lg:-top-10 lg:-bottom-20 lg:h-auto lg:px-0 lg:pt-10 xl:-bottom-32">
|
||||
<Image
|
||||
src="/images/logos/logo.svg"
|
||||
className="mx-auto mr-6 object-fill"
|
||||
@@ -89,10 +38,10 @@ export function Hero() {
|
||||
height="450"
|
||||
alt={""}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</Box>
|
||||
</Box>
|
||||
</Box>
|
||||
</Container>
|
||||
</div>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -30,7 +30,12 @@ export function MessageTableEntry(props: MessageTableEntryProps) {
|
||||
{props.enabled ? (
|
||||
<Box width={["full", "full", "full", "fit-content"]} maxWidth={["full", "full", "full", "2xl"]}>
|
||||
<Link href={`/messages/${item.id}`}>
|
||||
<LinkBox bg={item.is_assistant ? backgroundColor : backgroundColor2} p="4" borderRadius="md">
|
||||
<LinkBox
|
||||
bg={item.is_assistant ? backgroundColor : backgroundColor2}
|
||||
p="4"
|
||||
borderRadius="md"
|
||||
whiteSpace="pre-line"
|
||||
>
|
||||
{item.text}
|
||||
</LinkBox>
|
||||
</Link>
|
||||
|
||||
@@ -24,7 +24,7 @@ interface LabelRadioGroupProps {
|
||||
|
||||
const label_messages: { [label: string]: { description: string; explanation: string[] } } = {
|
||||
spam: {
|
||||
description: "The message is spam?",
|
||||
description: "Is the message spam?",
|
||||
explanation: [
|
||||
'We consider the following unwanted content as spam: trolling, intentional undermining of our purpose, illegal material, material that violates our code of conduct, and other things that are inappropriate for our dataset. We collect these under the common heading of "spam".',
|
||||
"This is not an assessment of whether this message is the best possible answer. Especially for prompts or user-replies, we very much want to retain all kinds of responses in the dataset, so that the assistant can learn to reply appropriately.",
|
||||
|
||||
@@ -59,8 +59,8 @@ function CheckboxSliderItem(props: {
|
||||
>
|
||||
<SliderTrack>
|
||||
<SliderFilledTrack />
|
||||
<SliderThumb />
|
||||
</SliderTrack>
|
||||
<SliderThumb bg="gainsboro" />
|
||||
</Slider>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -4,6 +4,7 @@ import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import { TrackedTextarea } from "src/components/Survey/TrackedTextarea";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { TaskSurveyProps } from "src/components/Tasks/Task";
|
||||
import { TaskHeader } from "src/components/Tasks/TaskHeader";
|
||||
|
||||
export const CreateTask = ({
|
||||
task,
|
||||
@@ -14,7 +15,6 @@ export const CreateTask = ({
|
||||
}: TaskSurveyProps<{ text: string }>) => {
|
||||
const cardColor = useColorModeValue("gray.50", "gray.800");
|
||||
const titleColor = useColorModeValue("gray.800", "gray.300");
|
||||
const labelColor = useColorModeValue("gray.600", "gray.400");
|
||||
|
||||
const [inputText, setInputText] = useState("");
|
||||
const textChangeHandler = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
@@ -33,14 +33,7 @@ export const CreateTask = ({
|
||||
<div data-cy="task" data-task-type="create-task">
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<Stack spacing="1">
|
||||
<Text fontSize="xl" fontWeight="bold" color={titleColor}>
|
||||
{taskType.label}
|
||||
</Text>
|
||||
<Text fontSize="md" color={labelColor}>
|
||||
{taskType.overview}
|
||||
</Text>
|
||||
</Stack>
|
||||
<TaskHeader taskType={taskType} />
|
||||
{task.conversation ? (
|
||||
<Box mt="4" borderRadius="lg" bg={cardColor} className="p-3 sm:p-6">
|
||||
<MessageTable messages={task.conversation.messages} />
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
import { Box, Stack, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import { Box, useColorModeValue } from "@chakra-ui/react";
|
||||
import { useEffect } from "react";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
import { Sortable } from "src/components/Sortable/Sortable";
|
||||
import { SurveyCard } from "src/components/Survey/SurveyCard";
|
||||
import { TaskSurveyProps } from "src/components/Tasks/Task";
|
||||
import { TaskHeader } from "src/components/Tasks/TaskHeader";
|
||||
|
||||
export const EvaluateTask = ({
|
||||
task,
|
||||
taskType,
|
||||
isEditable,
|
||||
isDisabled,
|
||||
onReplyChanged,
|
||||
}: TaskSurveyProps<{ ranking: number[] }>) => {
|
||||
const cardColor = useColorModeValue("gray.50", "gray.800");
|
||||
const titleColor = useColorModeValue("gray.800", "gray.300");
|
||||
const labelColor = useColorModeValue("gray.600", "gray.400");
|
||||
|
||||
let messages = [];
|
||||
if (task.conversation) {
|
||||
@@ -36,14 +36,7 @@ export const EvaluateTask = ({
|
||||
<div data-cy="task" data-task-type="evaluate-task">
|
||||
<Box mb="4">
|
||||
<SurveyCard>
|
||||
<Stack spacing="1">
|
||||
<Text fontSize="xl" fontWeight="bold" color={titleColor}>
|
||||
Instructions
|
||||
</Text>
|
||||
<Text fontSize="md" color={labelColor}>
|
||||
Given the following {sortables}, sort them from best to worst, best being first, worst being last.
|
||||
</Text>
|
||||
</Stack>
|
||||
<TaskHeader taskType={taskType} />
|
||||
<Box mt="4" p="6" borderRadius="lg" bg={cardColor}>
|
||||
<MessageTable messages={messages} />
|
||||
</Box>
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { Box } from "@chakra-ui/react";
|
||||
import { Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import { Box, useColorModeValue } from "@chakra-ui/react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { MessageView } from "src/components/Messages";
|
||||
import { MessageTable } from "src/components/Messages/MessageTable";
|
||||
@@ -7,6 +6,7 @@ import { LabelRadioGroup } from "src/components/Survey/LabelRadioGroup";
|
||||
import { LabelSliderGroup } from "src/components/Survey/LabelSliderGroup";
|
||||
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
|
||||
import { TaskSurveyProps } from "src/components/Tasks/Task";
|
||||
import { TaskHeader } from "src/components/Tasks/TaskHeader";
|
||||
import { TaskType } from "src/types/Task";
|
||||
|
||||
export const LabelTask = ({
|
||||
@@ -36,20 +36,12 @@ export const LabelTask = ({
|
||||
};
|
||||
|
||||
const cardColor = useColorModeValue("gray.50", "gray.800");
|
||||
const titleColor = useColorModeValue("gray.800", "gray.300");
|
||||
const labelColor = useColorModeValue("gray.600", "gray.400");
|
||||
|
||||
return (
|
||||
<div data-cy="task" data-task-type="label-task">
|
||||
<TwoColumnsWithCards>
|
||||
<>
|
||||
<Text fontSize="xl" fontWeight="bold" color={titleColor}>
|
||||
{taskType.label}
|
||||
</Text>
|
||||
<Text fontSize="md" color={labelColor}>
|
||||
{taskType.overview}
|
||||
</Text>
|
||||
|
||||
<TaskHeader taskType={taskType} />
|
||||
{task.conversation ? (
|
||||
<Box mt="4" p="6" borderRadius="lg" bg={cardColor}>
|
||||
<MessageTable
|
||||
|
||||
@@ -27,6 +27,8 @@ export const Task = ({ frontendId, task, trigger, mutate }) => {
|
||||
const replyContent = useRef<TaskContent>(null);
|
||||
const [showUnchangedWarning, setShowUnchangedWarning] = useState(false);
|
||||
|
||||
const rootEl = useRef<HTMLDivElement>(null);
|
||||
|
||||
const taskType = TaskTypes.find((taskType) => taskType.type === task.type && taskType.mode === task.mode);
|
||||
|
||||
const { trigger: sendRejection } = useSWRMutation("/api/reject_task", post, {
|
||||
@@ -89,6 +91,7 @@ export const Task = ({ frontendId, task, trigger, mutate }) => {
|
||||
content: replyContent.current,
|
||||
});
|
||||
setTaskStatus("SUBMITTED");
|
||||
scrollToTop(rootEl.current);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
@@ -138,7 +141,7 @@ export const Task = ({ frontendId, task, trigger, mutate }) => {
|
||||
}
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div ref={rootEl}>
|
||||
{taskTypeComponent()}
|
||||
<TaskControls
|
||||
task={task}
|
||||
@@ -164,3 +167,10 @@ export const Task = ({ frontendId, task, trigger, mutate }) => {
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const scrollToTop = (element: HTMLElement) => {
|
||||
while (element) {
|
||||
element.scrollTop = 0;
|
||||
element = element.parentElement;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
import { HStack, IconButton, Link, Stack, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import { FiHelpCircle } from "react-icons/fi";
|
||||
import type { TaskInfo } from "src/components/Tasks/TaskTypes";
|
||||
|
||||
interface TaskHeaderProps {
|
||||
/**
|
||||
* The `TaskInfo` representing how we present the task to a user.
|
||||
*/
|
||||
taskType: TaskInfo;
|
||||
}
|
||||
|
||||
/**
|
||||
* Presents the Task label, instructions, and help link
|
||||
*/
|
||||
const TaskHeader = ({ taskType }: TaskHeaderProps) => {
|
||||
const labelColor = useColorModeValue("gray.600", "gray.400");
|
||||
const titleColor = useColorModeValue("gray.800", "gray.300");
|
||||
return (
|
||||
<Stack spacing="1">
|
||||
<HStack>
|
||||
<Text fontSize="xl" fontWeight="bold" color={titleColor}>
|
||||
{taskType.label}
|
||||
</Text>
|
||||
<Link href={taskType.help_link} isExternal>
|
||||
<IconButton variant="ghost" aria-label="More Information" icon={<FiHelpCircle />} />
|
||||
</Link>
|
||||
</HStack>
|
||||
<Text fontSize="md" color={labelColor}>
|
||||
{taskType.overview}
|
||||
</Text>
|
||||
</Stack>
|
||||
);
|
||||
};
|
||||
|
||||
export { TaskHeader };
|
||||
@@ -0,0 +1 @@
|
||||
export * from "./TaskHeader";
|
||||
@@ -11,6 +11,7 @@ export interface TaskInfo {
|
||||
category: TaskCategory;
|
||||
pathname: string;
|
||||
type: string;
|
||||
help_link: string;
|
||||
mode?: string;
|
||||
overview?: string;
|
||||
instruction?: string;
|
||||
@@ -26,6 +27,7 @@ export const TaskTypes: TaskInfo[] = [
|
||||
desc: "Help us improve Open Assistant by starting a random task.",
|
||||
category: TaskCategory.Tasks,
|
||||
pathname: "/tasks/random",
|
||||
help_link: "https://projects.laion.ai/Open-Assistant/docs/guides/prompting",
|
||||
type: "random",
|
||||
update_type: "random",
|
||||
},
|
||||
@@ -35,6 +37,7 @@ export const TaskTypes: TaskInfo[] = [
|
||||
desc: "Write initial prompts to help Open Assistant to try replying to diverse messages.",
|
||||
category: TaskCategory.Create,
|
||||
pathname: "/create/initial_prompt",
|
||||
help_link: "https://projects.laion.ai/Open-Assistant/docs/guides/prompting",
|
||||
type: "initial_prompt",
|
||||
overview: "Create an initial message to send to the assistant",
|
||||
instruction: "Provide the initial prompt",
|
||||
@@ -45,6 +48,7 @@ export const TaskTypes: TaskInfo[] = [
|
||||
desc: "Chat with Open Assistant and help improve it’s responses as you interact with it.",
|
||||
category: TaskCategory.Create,
|
||||
pathname: "/create/user_reply",
|
||||
help_link: "https://projects.laion.ai/Open-Assistant/docs/tasks/reply_as_user",
|
||||
type: "prompter_reply",
|
||||
overview: "Given the following conversation, provide an adequate reply",
|
||||
instruction: "Provide the user's reply",
|
||||
@@ -55,6 +59,7 @@ export const TaskTypes: TaskInfo[] = [
|
||||
desc: "Help Open Assistant improve its responses to conversations with other users.",
|
||||
category: TaskCategory.Create,
|
||||
pathname: "/create/assistant_reply",
|
||||
help_link: "https://projects.laion.ai/Open-Assistant/docs/tasks/reply_as_assistant",
|
||||
type: "assistant_reply",
|
||||
overview: "Given the following conversation, provide an adequate reply",
|
||||
instruction: "Provide the assistant's reply",
|
||||
@@ -66,6 +71,8 @@ export const TaskTypes: TaskInfo[] = [
|
||||
category: TaskCategory.Evaluate,
|
||||
desc: "Help Open Assistant improve its responses to conversations with other users.",
|
||||
pathname: "/evaluate/rank_user_replies",
|
||||
help_link: "https://projects.laion.ai/Open-Assistant/docs/guides/prompting",
|
||||
overview: "Given the following User replies, sort them from best to worst, best being first, worst being last.",
|
||||
type: "rank_prompter_replies",
|
||||
update_type: "message_ranking",
|
||||
unchanged_title: "Order Unchanged",
|
||||
@@ -76,6 +83,9 @@ export const TaskTypes: TaskInfo[] = [
|
||||
desc: "Score prompts given by Open Assistant based on their accuracy and readability.",
|
||||
category: TaskCategory.Evaluate,
|
||||
pathname: "/evaluate/rank_assistant_replies",
|
||||
help_link: "https://projects.laion.ai/Open-Assistant/docs/guides/prompting",
|
||||
overview:
|
||||
"Given the following Assistant replies, sort them from best to worst, best being first, worst being last.",
|
||||
type: "rank_assistant_replies",
|
||||
update_type: "message_ranking",
|
||||
unchanged_title: "Order Unchanged",
|
||||
@@ -86,6 +96,8 @@ export const TaskTypes: TaskInfo[] = [
|
||||
desc: "Score prompts given by Open Assistant based on their accuracy and readability.",
|
||||
category: TaskCategory.Evaluate,
|
||||
pathname: "/evaluate/rank_initial_prompts",
|
||||
help_link: "https://projects.laion.ai/Open-Assistant/docs/guides/prompting",
|
||||
overview: "Given the following inital prompts, sort them from best to worst, best being first, worst being last.",
|
||||
type: "rank_initial_prompts",
|
||||
update_type: "message_ranking",
|
||||
unchanged_title: "Order Unchanged",
|
||||
@@ -97,6 +109,7 @@ export const TaskTypes: TaskInfo[] = [
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_initial_prompt",
|
||||
help_link: "https://projects.laion.ai/Open-Assistant/docs/guides/prompting",
|
||||
overview: "Provide labels for the following prompt",
|
||||
type: "label_initial_prompt",
|
||||
mode: "full",
|
||||
@@ -107,6 +120,7 @@ export const TaskTypes: TaskInfo[] = [
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_prompter_reply",
|
||||
help_link: "https://projects.laion.ai/Open-Assistant/docs/tasks/label_prompter_reply",
|
||||
overview: "Given the following discussion, provide labels for the final prompt",
|
||||
type: "label_prompter_reply",
|
||||
mode: "full",
|
||||
@@ -117,6 +131,7 @@ export const TaskTypes: TaskInfo[] = [
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_assistant_reply",
|
||||
help_link: "https://projects.laion.ai/Open-Assistant/docs/tasks/label_assistant_reply",
|
||||
overview: "Given the following discussion, provide labels for the final prompt.",
|
||||
type: "label_assistant_reply",
|
||||
mode: "full",
|
||||
@@ -128,6 +143,7 @@ export const TaskTypes: TaskInfo[] = [
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_initial_prompt",
|
||||
help_link: "https://projects.laion.ai/Open-Assistant/docs/guides/prompting",
|
||||
overview: "Read the following prompt and then answer the question about it.",
|
||||
type: "label_initial_prompt",
|
||||
mode: "simple",
|
||||
@@ -138,6 +154,7 @@ export const TaskTypes: TaskInfo[] = [
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_prompter_reply",
|
||||
help_link: "https://projects.laion.ai/Open-Assistant/docs/guides/prompting",
|
||||
overview: "Read the following conversation and then answer the question about the last prompt in the discussion.",
|
||||
type: "label_prompter_reply",
|
||||
mode: "simple",
|
||||
@@ -148,6 +165,7 @@ export const TaskTypes: TaskInfo[] = [
|
||||
desc: "Provide labels for a prompt.",
|
||||
category: TaskCategory.Label,
|
||||
pathname: "/label/label_assistant_reply",
|
||||
help_link: "https://projects.laion.ai/Open-Assistant/docs/guides/prompting",
|
||||
overview: "Read the following conversation and then answer the question about the last prompt in the discussion.",
|
||||
type: "label_assistant_reply",
|
||||
mode: "simple",
|
||||
|
||||
@@ -113,7 +113,7 @@ export class OasstApiClient {
|
||||
type: taskType,
|
||||
user: {
|
||||
id: userToken.sub,
|
||||
display_name: userToken.name || userToken.email,
|
||||
display_name: userToken.name,
|
||||
auth_method: "local",
|
||||
},
|
||||
});
|
||||
@@ -146,7 +146,7 @@ export class OasstApiClient {
|
||||
type: updateType,
|
||||
user: {
|
||||
id: userToken.sub,
|
||||
display_name: userToken.name || userToken.email,
|
||||
display_name: userToken.name,
|
||||
auth_method: "local",
|
||||
},
|
||||
task_id: taskId,
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import { Box, Button, Center, Link, Text, useColorModeValue } from "@chakra-ui/react";
|
||||
import { Box, Button, Center, Link, Text } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useRouter } from "next/router";
|
||||
import { FiAlertTriangle } from "react-icons/fi";
|
||||
import { PageEmptyState } from "src/components/EmptyState";
|
||||
import { EmptyState } from "src/components/EmptyState";
|
||||
import { getTransparentHeaderLayout } from "src/components/Layout";
|
||||
|
||||
function Error() {
|
||||
@@ -13,7 +12,7 @@ function Error() {
|
||||
<meta name="404" content="Sorry, this page doesn't exist." />
|
||||
</Head>
|
||||
<Center flexDirection="column" gap="4" fontSize="lg" className="subpixel-antialiased">
|
||||
<PageEmptyState />
|
||||
<EmptyState text="Sorry, the page you are looking for does not exist." icon={FiAlertTriangle} />
|
||||
<Box display="flex" flexDirection="column" alignItems="center" gap="2" mt="6">
|
||||
<Text fontSize="sm">If you were trying to contribute data but ended up here, please file a bug.</Text>
|
||||
<Button
|
||||
|
||||
+25
-11
@@ -1,32 +1,46 @@
|
||||
import { Button, Link, Stack } from "@chakra-ui/react";
|
||||
import { Box, Button, Center, Link, Text } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import NextLink from "next/link";
|
||||
import { FiAlertTriangle } from "react-icons/fi";
|
||||
import { EmptyState } from "src/components/EmptyState";
|
||||
import { getTransparentHeaderLayout } from "src/components/Layout";
|
||||
|
||||
export default function Error() {
|
||||
function ServerError() {
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>500 - Open Assistant</title>
|
||||
<meta name="404" content="Sorry, this page doesn't exist." />
|
||||
</Head>
|
||||
<main className="flex h-3/4 items-center justify-center overflow-hidden subpixel-antialiased text-xl">
|
||||
<Stack>
|
||||
<p>Sorry, We encountered a server error. We're not sure what went wrong</p>
|
||||
<p>Please file a but below and describe what you were trying to accomplish</p>
|
||||
<Button leftIcon={<FiAlertTriangle className="text-blue-500" aria-hidden="true" />} variant="solid">
|
||||
<Center flexDirection="column" gap="4" fontSize="lg" className="subpixel-antialiased">
|
||||
<EmptyState
|
||||
text="Sorry, we encountered a server error. We're not sure what went wrong."
|
||||
icon={FiAlertTriangle}
|
||||
/>
|
||||
<Box display="flex" flexDirection="column" alignItems="center" gap="2" mt="6">
|
||||
<Text fontSize="sm">If you were trying to contribute data but ended up here, please file a bug.</Text>
|
||||
<Button
|
||||
width="fit-content"
|
||||
leftIcon={<FiAlertTriangle className="text-blue-500" aria-hidden="true" />}
|
||||
variant="solid"
|
||||
size="xs"
|
||||
>
|
||||
<Link
|
||||
as={NextLink}
|
||||
key="Report a Bug"
|
||||
href="https://github.com/LAION-AI/Open-Assistant/issues/new/choose"
|
||||
aria-label="Report a Bug"
|
||||
className="flex items-center"
|
||||
_hover={{ textDecoration: "none" }}
|
||||
isExternal
|
||||
>
|
||||
Report a Bug
|
||||
</Link>
|
||||
</Button>
|
||||
</Stack>
|
||||
</main>
|
||||
</Box>
|
||||
</Center>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
ServerError.getLayout = getTransparentHeaderLayout;
|
||||
|
||||
export default ServerError;
|
||||
|
||||
@@ -3,11 +3,13 @@ import "focus-visible";
|
||||
|
||||
import type { AppProps } from "next/app";
|
||||
import { SessionProvider } from "next-auth/react";
|
||||
import { appWithTranslation } from "next-i18next";
|
||||
import { FlagsProvider } from "react-feature-flags";
|
||||
import { getDefaultLayout, NextPageWithLayout } from "src/components/Layout";
|
||||
import flags from "src/flags";
|
||||
import { SWRConfig, SWRConfiguration } from "swr";
|
||||
|
||||
import nextI18NextConfig from "../../next-i18next.config.js";
|
||||
import { Chakra, getServerSideProps } from "../styles/Chakra";
|
||||
|
||||
type AppPropsWithLayout = AppProps & {
|
||||
@@ -34,4 +36,4 @@ function MyApp({ Component, pageProps: { session, cookies, ...pageProps } }: App
|
||||
);
|
||||
}
|
||||
export { getServerSideProps };
|
||||
export default MyApp;
|
||||
export default appWithTranslation(MyApp, nextI18NextConfig);
|
||||
|
||||
@@ -2,27 +2,11 @@ import { Button, Input, InputGroup } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import Router from "next/router";
|
||||
import { useSession } from "next-auth/react";
|
||||
import React, { useState } from "react";
|
||||
import React from "react";
|
||||
import { Control, useForm, useWatch } from "react-hook-form";
|
||||
|
||||
export default function Account() {
|
||||
const { data: session } = useSession();
|
||||
const [username, setUsername] = useState("");
|
||||
|
||||
const updateUser = async (e: React.SyntheticEvent) => {
|
||||
e.preventDefault();
|
||||
try {
|
||||
const body = { username };
|
||||
await fetch("/api/username", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
session.user.name = username;
|
||||
await Router.push("/account");
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
}
|
||||
};
|
||||
|
||||
if (!session) {
|
||||
return;
|
||||
@@ -39,21 +23,52 @@ export default function Account() {
|
||||
<div className="oa-basic-theme">
|
||||
<main className="h-3/4 z-0 flex flex-col items-center justify-center">
|
||||
<p>{session.user.name || "No username"}</p>
|
||||
<form onSubmit={updateUser}>
|
||||
<InputGroup>
|
||||
<Input
|
||||
onChange={(e) => setUsername(e.target.value)}
|
||||
placeholder="Edit Username"
|
||||
type="text"
|
||||
value={username}
|
||||
></Input>
|
||||
<Button disabled={!username} type="submit" value="Change">
|
||||
Submit
|
||||
</Button>
|
||||
</InputGroup>
|
||||
</form>
|
||||
<EditForm></EditForm>
|
||||
</main>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
const EditForm = () => {
|
||||
const { data: session } = useSession();
|
||||
|
||||
const updateUser = async ({ username }: { username: string }) => {
|
||||
try {
|
||||
const body = { username };
|
||||
await fetch("/api/username", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
session.user.name = username;
|
||||
await Router.push("/account");
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
}
|
||||
};
|
||||
|
||||
const { register, handleSubmit, control } = useForm<{ username: string }>({
|
||||
defaultValues: {
|
||||
username: session?.user.name,
|
||||
},
|
||||
});
|
||||
|
||||
return (
|
||||
<form onSubmit={handleSubmit(updateUser)}>
|
||||
<InputGroup>
|
||||
<Input placeholder="Edit Username" type="text" {...register("username")}></Input>
|
||||
<SubmitButton control={control}></SubmitButton>
|
||||
</InputGroup>
|
||||
</form>
|
||||
);
|
||||
};
|
||||
|
||||
const SubmitButton = ({ control }: { control: Control<{ username: string }> }) => {
|
||||
const username = useWatch({ control, name: "username" });
|
||||
return (
|
||||
<Button disabled={!username} type="submit" value="Change">
|
||||
Submit
|
||||
</Button>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -7,6 +7,7 @@ import CredentialsProvider from "next-auth/providers/credentials";
|
||||
import DiscordProvider from "next-auth/providers/discord";
|
||||
import EmailProvider from "next-auth/providers/email";
|
||||
import prisma from "src/lib/prismadb";
|
||||
import { generateUsername } from "unique-username-generator";
|
||||
|
||||
const providers: Provider[] = [];
|
||||
|
||||
@@ -90,6 +91,7 @@ export const authOptions: AuthOptions = {
|
||||
async session({ session, token }) {
|
||||
session.user.role = token.role;
|
||||
session.user.isNew = token.isNew;
|
||||
session.user.name = token.name;
|
||||
return session;
|
||||
},
|
||||
/**
|
||||
@@ -97,10 +99,11 @@ export const authOptions: AuthOptions = {
|
||||
* This let's use forward the role to the session object.
|
||||
*/
|
||||
async jwt({ token }) {
|
||||
const { isNew, role } = await prisma.user.findUnique({
|
||||
const { isNew, name, role } = await prisma.user.findUnique({
|
||||
where: { id: token.sub },
|
||||
select: { role: true, isNew: true },
|
||||
select: { name: true, role: true, isNew: true },
|
||||
});
|
||||
token.name = name;
|
||||
token.role = role;
|
||||
token.isNew = isNew;
|
||||
return token;
|
||||
@@ -110,7 +113,18 @@ export const authOptions: AuthOptions = {
|
||||
/**
|
||||
* Update the user's role after they have successfully signed in
|
||||
*/
|
||||
async signIn({ user, account }) {
|
||||
async signIn({ user, account, isNewUser }) {
|
||||
if (isNewUser && account.provider === "email") {
|
||||
await prisma.user.update({
|
||||
data: {
|
||||
name: generateUsername(),
|
||||
},
|
||||
where: {
|
||||
id: user.id,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Get the admin list for the user's auth type.
|
||||
const adminForAccountType = adminUserMap.get(account.provider);
|
||||
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
import { withoutRole } from "src/lib/auth";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
/**
|
||||
* Updates the user's `name` field in the `User` table.
|
||||
*/
|
||||
const handler = withoutRole("banned", async (req, res, token) => {
|
||||
const { username } = req.body;
|
||||
const { name } = await prisma.user.update({
|
||||
where: {
|
||||
id: token.sub,
|
||||
},
|
||||
data: {
|
||||
name: username,
|
||||
},
|
||||
});
|
||||
res.json({ name });
|
||||
});
|
||||
|
||||
export default handler;
|
||||
@@ -1,20 +0,0 @@
|
||||
import { getSession } from "next-auth/react";
|
||||
import prisma from "src/lib/prismadb";
|
||||
|
||||
// POST /api/post
|
||||
// Required fields in body: title
|
||||
// Optional fields in body: content
|
||||
export default async function handle(req, res) {
|
||||
const { username } = req.body;
|
||||
|
||||
const session = await getSession({ req });
|
||||
const result = await prisma.user.update({
|
||||
where: {
|
||||
email: session.user.email,
|
||||
},
|
||||
data: {
|
||||
name: username,
|
||||
},
|
||||
});
|
||||
res.json({ name: result.name });
|
||||
}
|
||||
@@ -6,11 +6,12 @@ import Link from "next/link";
|
||||
import { useRouter } from "next/router";
|
||||
import { ClientSafeProvider, getProviders, signIn } from "next-auth/react";
|
||||
import React, { useEffect, useRef, useState } from "react";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { FaBug, FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa";
|
||||
import { AuthLayout } from "src/components/AuthLayout";
|
||||
import { Footer } from "src/components/Footer";
|
||||
import { Header } from "src/components/Header";
|
||||
import { RoleSelect } from "src/components/RoleSelect";
|
||||
import { Role, RoleSelect } from "src/components/RoleSelect";
|
||||
|
||||
export type SignInErrorTypes =
|
||||
| "Signin"
|
||||
@@ -60,15 +61,14 @@ function Signin({ providers }: SigninProps) {
|
||||
}
|
||||
}, [router]);
|
||||
|
||||
const signinWithEmail = (ev: React.FormEvent) => {
|
||||
ev.preventDefault();
|
||||
signIn(email.id, { callbackUrl: "/dashboard", email: emailEl.current.value });
|
||||
const signinWithEmail = (data: { email: string }) => {
|
||||
signIn(email.id, { callbackUrl: "/dashboard", email: data.email });
|
||||
};
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
const bgColorClass = colorMode === "light" ? "bg-gray-50" : "bg-chakra-gray-900";
|
||||
const buttonBgColor = colorMode === "light" ? "#2563eb" : "#2563eb";
|
||||
|
||||
const { register, handleSubmit } = useForm<{ email: string }>();
|
||||
return (
|
||||
<div className={bgColorClass}>
|
||||
<Head>
|
||||
@@ -79,7 +79,7 @@ function Signin({ providers }: SigninProps) {
|
||||
<Stack spacing="2">
|
||||
{credentials && <DebugSigninForm credentials={credentials} bgColorClass={bgColorClass} />}
|
||||
{email && (
|
||||
<form onSubmit={signinWithEmail}>
|
||||
<form onSubmit={handleSubmit(signinWithEmail)}>
|
||||
<Stack>
|
||||
<Input
|
||||
type="email"
|
||||
@@ -87,7 +87,7 @@ function Signin({ providers }: SigninProps) {
|
||||
variant="outline"
|
||||
size="lg"
|
||||
placeholder="Email Address"
|
||||
ref={emailEl}
|
||||
{...register("email")}
|
||||
/>
|
||||
<SigninButton data-cy="signin-email-button" leftIcon={<FaEnvelope />}>
|
||||
Continue with Email
|
||||
@@ -174,23 +174,35 @@ const SigninButton = (props: ButtonProps) => {
|
||||
);
|
||||
};
|
||||
|
||||
interface DebugSigninFormData {
|
||||
username: string;
|
||||
role: Role;
|
||||
}
|
||||
|
||||
const DebugSigninForm = ({ credentials, bgColorClass }: { credentials: ClientSafeProvider; bgColorClass: string }) => {
|
||||
const debugUsernameEl = useRef(null);
|
||||
const roleRef = useRef(null);
|
||||
function signinWithDebugCredentials(ev: React.FormEvent) {
|
||||
ev.preventDefault();
|
||||
const { register, handleSubmit } = useForm<DebugSigninFormData>({
|
||||
defaultValues: {
|
||||
role: "general",
|
||||
username: "dev",
|
||||
},
|
||||
});
|
||||
|
||||
function signinWithDebugCredentials(data: DebugSigninFormData) {
|
||||
signIn(credentials.id, {
|
||||
callbackUrl: "/dashboard",
|
||||
username: debugUsernameEl.current.value,
|
||||
role: roleRef.current.value,
|
||||
...data,
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<form onSubmit={signinWithDebugCredentials} className="border-2 border-orange-600 rounded-md p-4 relative">
|
||||
<form
|
||||
onSubmit={handleSubmit(signinWithDebugCredentials)}
|
||||
className="border-2 border-orange-600 rounded-md p-4 relative"
|
||||
>
|
||||
<span className={`text-orange-600 absolute -top-3 left-5 ${bgColorClass} px-1`}>For Debugging Only</span>
|
||||
<Stack>
|
||||
<Input variant="outline" size="lg" placeholder="Username" ref={debugUsernameEl} />
|
||||
<RoleSelect defaultValue={"general"} ref={roleRef}></RoleSelect>
|
||||
<Input variant="outline" size="lg" placeholder="Username" {...register("username")} />
|
||||
<RoleSelect {...register("role")}></RoleSelect>
|
||||
<SigninButton leftIcon={<FaBug />}>Continue with Debug User</SigninButton>
|
||||
</Stack>
|
||||
</form>
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import { Box } from "@chakra-ui/react";
|
||||
import Head from "next/head";
|
||||
import { useRouter } from "next/router";
|
||||
import { useSession } from "next-auth/react";
|
||||
import { useTranslation } from "next-i18next";
|
||||
import { serverSideTranslations } from "next-i18next/serverSideTranslations";
|
||||
import { useEffect } from "react";
|
||||
import { CallToAction } from "src/components/CallToAction";
|
||||
import { Faq } from "src/components/Faq";
|
||||
@@ -10,6 +13,7 @@ import { getTransparentHeaderLayout } from "src/components/Layout";
|
||||
const Home = () => {
|
||||
const router = useRouter();
|
||||
const { status } = useSession();
|
||||
const { t } = useTranslation("index");
|
||||
useEffect(() => {
|
||||
if (status === "authenticated") {
|
||||
router.push("/dashboard");
|
||||
@@ -19,21 +23,24 @@ const Home = () => {
|
||||
return (
|
||||
<>
|
||||
<Head>
|
||||
<title>Open Assistant</title>
|
||||
<meta
|
||||
name="description"
|
||||
content="Conversational AI for everyone. An open source project to create a chat enabled GPT LLM run by LAION and contributors around the world."
|
||||
/>
|
||||
<title>{t("title")}</title>
|
||||
<meta name="description" content={t("description")} />
|
||||
</Head>
|
||||
<main className="oa-basic-theme">
|
||||
<Box as="main" className="oa-basic-theme">
|
||||
<Hero />
|
||||
<CallToAction />
|
||||
<Faq />
|
||||
</main>
|
||||
</Box>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
Home.getLayout = getTransparentHeaderLayout;
|
||||
|
||||
export const getStaticProps = async ({ locale }) => ({
|
||||
props: {
|
||||
...(await serverSideTranslations(locale, ["index", "common"])),
|
||||
},
|
||||
});
|
||||
|
||||
export default Home;
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
# Page Tests
|
||||
|
||||
Put all page tests in this directory with the patter `MyPage.test.jsx`. We can't place them in `src/pages` due to how
|
||||
NextJS generates page routes.
|
||||
Reference in New Issue
Block a user