diff --git a/.github/workflows/build-postgres.yaml b/.github/workflows/build-postgres.yaml
new file mode 100644
index 00000000..2522a1d7
--- /dev/null
+++ b/.github/workflows/build-postgres.yaml
@@ -0,0 +1,17 @@
+# Runs the shared docker-build workflow for the oasst-postgres image on pushes
+# to main that touch files under docker/oasst-postgres/.
+name: Build OASST Postgres image
+
+on:
+  push:
+    branches:
+      - main
+    paths:
+      - docker/oasst-postgres/**
+
+jobs:
+  build-postgres:
+    # Job-level `uses` with a local path calls a reusable workflow from this repo.
+    uses: ./.github/workflows/docker-build.yaml
+    with:
+      image-name: oasst-postgres
+      context: ./docker/oasst-postgres
+      dockerfile: docker/oasst-postgres/Dockerfile
+      build-args: ""
diff --git a/.github/workflows/deploy-to-node.yaml b/.github/workflows/deploy-to-node.yaml
index f107d0af..5da5f59f 100644
--- a/.github/workflows/deploy-to-node.yaml
+++ b/.github/workflows/deploy-to-node.yaml
@@ -34,8 +34,13 @@ jobs:
WEB_EMAIL_SERVER_USER: ${{ secrets.DEV_WEB_EMAIL_SERVER_USER }}
WEB_NEXTAUTH_SECRET: ${{ secrets.NEXTAUTH_SECRET }}
S3_BUCKET_NAME: ${{ secrets.S3_BUCKET_NAME }}
+ S3_REGION: ${{ secrets.S3_REGION }}
AWS_ACCESS_KEY: ${{ secrets.AWS_ACCESS_KEY }}
AWS_SECRET_KEY: ${{ secrets.AWS_SECRET_KEY }}
+ MAX_ACTIVE_TREES: ${{ vars.MAX_ACTIVE_TREES }}
+ MAX_TREE_DEPTH: ${{ vars.MAX_TREE_DEPTH }}
+ GOAL_TREE_SIZE: ${{ vars.GOAL_TREE_SIZE }}
+ SKIP_TOXICITY_CALCULATION: ${{ vars.SKIP_TOXICITY_CALCULATION }}
steps:
- name: Checkout
uses: actions/checkout@v2
diff --git a/ansible/deploy-to-node.yaml b/ansible/deploy-to-node.yaml
index 94746437..a89d969a 100644
--- a/ansible/deploy-to-node.yaml
+++ b/ansible/deploy-to-node.yaml
@@ -57,8 +57,9 @@
- name: Create postgres containers
community.docker.docker_container:
name: "oasst-{{ stack_name }}-postgres-{{ item.name }}"
- image: postgres:15
+ image: ghcr.io/laion-ai/open-assistant/oasst-postgres
state: started
+ pull: true
recreate: "{{ (stack_name == 'dev') | bool }}"
restart_policy: always
network_mode: "oasst-{{ stack_name }}"
@@ -66,6 +67,13 @@
POSTGRES_USER: postgres
POSTGRES_PASSWORD: "{{ postgres_password }}"
POSTGRES_DB: postgres
+ S3_BUCKET_NAME:
+ "{{ lookup('ansible.builtin.env', 'S3_BUCKET_NAME') }}"
+ AWS_ACCESS_KEY_ID:
+ "{{ lookup('ansible.builtin.env', 'AWS_ACCESS_KEY') }}"
+ AWS_SECRET_ACCESS_KEY:
+ "{{ lookup('ansible.builtin.env', 'AWS_SECRET_KEY') }}"
+ AWS_DEFAULT_REGION: "{{ lookup('ansible.builtin.env', 'S3_REGION') }}"
volumes:
- "oasst-{{ stack_name }}-postgres-{{ item.name
}}:/var/lib/postgresql/data"
@@ -78,29 +86,6 @@
- name: backend
- name: web
- - name: Copy pgbackrest.conf to managed node
- ansible.builtin.copy:
- src: ./pgbackrest.conf
- dest: "./{{ stack_name }}/pgbackrest.conf"
- mode: 0644
-
- - name: Create pgbackrest container
- community.docker.docker_container:
- name: "oasst-{{ stack_name }}-pgbackrest"
- image: woblerr/pgbackrest:2.43
- state: "{{ 'stopped' if stack_name == 'production' else 'absent' }}"
- network_mode: "oasst-{{ stack_name }}"
- volumes:
- - "./{{ stack_name }}/pgbackrest.conf:/etc/pgbackrest/pgbackrest.conf"
- - "oasst-{{ stack_name }}-postgres-backend:/var/lib/postgresql/data"
- env:
- PGBACKREST_REPO1_S3_BUCKET:
- "{{ lookup('ansible.builtin.env', 'S3_BUCKET_NAME') }}"
- PGBACKREST_REPO1_S3_KEY:
- "{{ lookup('ansible.builtin.env', 'AWS_ACCESS_KEY') }}"
- PGBACKREST_REPO1_S3_KEY_SECRET:
- "{{ lookup('ansible.builtin.env', 'AWS_SECRET_KEY') }}"
-
- name: Run the oasst oasst-backend
community.docker.docker_container:
name: "oasst-{{ stack_name }}-backend"
@@ -122,8 +107,18 @@
RATE_LIMIT: "{{ 'false' if stack_name == 'dev' else 'true' }}"
DEBUG_SKIP_EMBEDDING_COMPUTATION: "true"
DEBUG_SKIP_TOXICITY_CALCULATION:
- "{{ 'true' if stack_name == 'dev' else 'false' }}"
+ "{{ lookup('ansible.builtin.env', 'SKIP_TOXICITY_CALCULATION') |
+ default('true', true) }}"
OFFICIAL_WEB_API_KEY: "{{ web_api_key }}"
+ TREE_MANAGER__MAX_ACTIVE_TREES:
+ "{{ lookup('ansible.builtin.env', 'MAX_ACTIVE_TREES') |
+ default('10', true) }}"
+ TREE_MANAGER__MAX_TREE_DEPTH:
+ "{{ lookup('ansible.builtin.env', 'MAX_TREE_DEPTH') | default('5',
+ true) }}"
+ TREE_MANAGER__GOAL_TREE_SIZE:
+ "{{ lookup('ansible.builtin.env', 'GOAL_TREE_SIZE') | default('15',
+ true) }}"
ports:
- "{{ backend_port }}:8080"
diff --git a/ansible/pgbackrest.conf b/ansible/pgbackrest.conf
index 036826d3..147ff8c1 100644
--- a/ansible/pgbackrest.conf
+++ b/ansible/pgbackrest.conf
@@ -2,7 +2,7 @@
pg1-path=/var/lib/postgresql/data
[global]
-repo1-retention-full=3 # keep last 3 backups
+repo1-retention-full=3
repo1-type=s3
repo1-path=/oasst-prod
repo1-s3-region=us-east-1
diff --git a/backend/alembic/versions/2023_01_19_2153-7f0a28a156f4_switch_to_timestamp_with_tz.py b/backend/alembic/versions/2023_01_19_2153-7f0a28a156f4_switch_to_timestamp_with_tz.py
new file mode 100644
index 00000000..d3096b2f
--- /dev/null
+++ b/backend/alembic/versions/2023_01_19_2153-7f0a28a156f4_switch_to_timestamp_with_tz.py
@@ -0,0 +1,47 @@
+"""switch to timestamp with tz
+
+Revision ID: 7f0a28a156f4
+Revises: 0964ac95170d
+Create Date: 2023-01-19 21:53:01.107137
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "7f0a28a156f4"
+down_revision = "0964ac95170d"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    """Alter every timestamp column listed below to ``sa.DateTime(timezone=True)``."""
+    # (table, column) pairs; the columns are altered in exactly this order.
+    timestamp_columns = (
+        ("user_stats", "modified_date"),
+        ("user_stats", "base_date"),
+        ("journal_integration", "last_run"),
+        ("message_embedding", "created_date"),
+        ("message_reaction", "created_date"),
+        ("message_toxicity", "created_date"),
+        ("message", "created_date"),
+        ("task", "created_date"),
+        ("task", "expiry_date"),
+        ("text_labels", "created_date"),
+        ("user", "created_date"),
+    )
+    for table, column in timestamp_columns:
+        op.alter_column(table_name=table, column_name=column, type_=sa.DateTime(timezone=True))
+
+
+def downgrade() -> None:
+    """Alter every timestamp column listed below back to ``sa.DateTime(timezone=False)``."""
+    # (table, column) pairs; the columns are altered in exactly this order.
+    timestamp_columns = (
+        ("user_stats", "modified_date"),
+        ("user_stats", "base_date"),
+        ("journal_integration", "last_run"),
+        ("message_embedding", "created_date"),
+        ("message_reaction", "created_date"),
+        ("message_toxicity", "created_date"),
+        ("message", "created_date"),
+        ("task", "created_date"),
+        ("task", "expiry_date"),
+        ("text_labels", "created_date"),
+        ("user", "created_date"),
+    )
+    for table, column in timestamp_columns:
+        op.alter_column(table_name=table, column_name=column, type_=sa.DateTime(timezone=False))
diff --git a/backend/oasst_backend/api/v1/admin.py b/backend/oasst_backend/api/v1/admin.py
index e8d3078e..59c5efda 100644
--- a/backend/oasst_backend/api/v1/admin.py
+++ b/backend/oasst_backend/api/v1/admin.py
@@ -1,7 +1,17 @@
+from datetime import datetime
+from uuid import UUID
+
import pydantic
from fastapi import APIRouter, Depends
from loguru import logger
from oasst_backend.api import deps
+from oasst_backend.config import Settings, settings
+from oasst_backend.models import ApiClient, User
+from oasst_backend.prompt_repository import PromptRepository
+from oasst_backend.tree_manager import TreeManager
+from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
+from oasst_shared.schemas.protocol import SystemStats
+from oasst_shared.utils import ScopeTimer, unaware_to_utc
router = APIRouter()
@@ -13,7 +23,7 @@ class CreateApiClientRequest(pydantic.BaseModel):
admin_email: str | None = None
-@router.post("/api_client")
+@router.post("/api_client", response_model=str)
async def create_api_client(
request: CreateApiClientRequest,
root_token: str = Depends(deps.get_root_token),
@@ -29,3 +39,125 @@ async def create_api_client(
)
logger.info(f"Created api_client with key {api_client.api_key}")
return api_client.api_key
+
+
+@router.get("/backend_settings/full", response_model=Settings)
+async def get_backend_settings_full(api_client: ApiClient = Depends(deps.get_trusted_api_client)) -> Settings:
+    """Log which trusted API client asked, then return the module-level ``settings`` object."""
+    audit_message = f"Backend settings requested by trusted api_client {api_client.id} (admin_email: {api_client.admin_email}, frontend_type: {api_client.frontend_type})"
+    logger.info(audit_message)
+    return settings
+
+
+class PublicSettings(pydantic.BaseModel):
+    """Subset of backend settings which can be retrieved by untrusted API clients."""
+
+    # The public settings endpoint builds this model from `settings.dict()`, so
+    # every field name here has to match a key of that dictionary.
+    PROJECT_NAME: str
+    API_V1_STR: str
+    # Debug toggles.
+    DEBUG_USE_SEED_DATA: bool
+    DEBUG_ALLOW_SELF_LABELING: bool
+    DEBUG_SKIP_EMBEDDING_COMPUTATION: bool
+    DEBUG_SKIP_TOXICITY_CALCULATION: bool
+    DEBUG_DATABASE_ECHO: bool
+    # User-stats intervals (units are not visible here -- TODO confirm in config).
+    USER_STATS_INTERVAL_DAY: int
+    USER_STATS_INTERVAL_WEEK: int
+    USER_STATS_INTERVAL_MONTH: int
+    USER_STATS_INTERVAL_TOTAL: int
+
+
+@router.get("/backend_settings/public", response_model=PublicSettings)
+async def get_backend_settings_public(api_client: ApiClient = Depends(deps.get_api_client)) -> PublicSettings:
+    """Return the backend settings restricted to the fields declared on ``PublicSettings``."""
+    all_settings = settings.dict()
+    return PublicSettings(**all_settings)
+
+
+class PurgeResultModel(pydantic.BaseModel):
+    # Response body of the purge endpoints in this module.
+    # System stats captured inside the transaction before the purge ran.
+    before: SystemStats
+    # System stats captured inside the same transaction after the purge ran.
+    after: SystemStats
+    # Echo of the request's `preview` flag.
+    preview: bool
+    # `ScopeTimer.elapsed` for the purge transaction (presumably seconds -- TODO confirm).
+    duration: float
+
+
+@router.post("/purge_user/{user_id}", response_model=PurgeResultModel)
+async def purge_user(
+    user_id: UUID,
+    preview: bool = True,
+    ban: bool = True,
+    api_client: ApiClient = Depends(deps.get_trusted_api_client),
+) -> PurgeResultModel:
+    """Purge a user through ``TreeManager.purge_user`` and report stats before and after.
+
+    Args:
+        user_id: ID of the user to purge.
+        preview: When true (the default) the transaction runs with ``CommitMode.ROLLBACK``,
+            otherwise with ``CommitMode.COMMIT``.
+        ban: Forwarded unchanged to ``TreeManager.purge_user``.
+        api_client: Trusted API client resolved by the dependency.
+
+    Returns:
+        The stats snapshots taken before and after the purge, the ``preview`` flag and
+        the elapsed time reported by ``ScopeTimer``.
+    """
+    assert api_client.trusted
+
+    @managed_tx_function(CommitMode.ROLLBACK if preview else CommitMode.COMMIT)
+    def purge_tx(session: deps.Session) -> tuple[User, SystemStats, SystemStats]:
+        pr = PromptRepository(session, api_client)
+
+        stats_before = pr.get_stats()
+
+        user = pr.user_repository.get_user(user_id)
+        tm = TreeManager(session, pr)
+        tm.purge_user(user_id=user_id, ban=ban)
+
+        # Detach the user from the session; presumably so its attributes stay
+        # readable for the log lines below once the transaction has ended.
+        session.expunge(user)
+        return user, stats_before, pr.get_stats()
+
+    timer = ScopeTimer()
+    user, before, after = purge_tx()
+    timer.stop()
+
+    # A preview is logged at info level; a real purge is logged as a warning.
+    if preview:
+        logger.info(
+            f"PURGE USER PREVIEW: '{user.display_name}' (id: {str(user_id)}; username: '{user.username}'; auth-method: '{user.auth_method}')"
+        )
+    else:
+        logger.warning(
+            f"PURGE USER: '{user.display_name}' (id: {str(user_id)}; username: '{user.username}'; auth-method: '{user.auth_method}')"
+        )
+
+    logger.info(f"{before=}; {after=}")
+    return PurgeResultModel(before=before, after=after, preview=preview, duration=timer.elapsed)
+
+
+@router.post("/purge_user/{user_id}/messages", response_model=PurgeResultModel)
+async def purge_user_messages(
+    user_id: UUID,
+    purge_initial_prompts: bool = False,
+    min_date: datetime | None = None,
+    max_date: datetime | None = None,
+    preview: bool = True,
+    api_client: ApiClient = Depends(deps.get_trusted_api_client),
+) -> PurgeResultModel:
+    """Purge a user's messages through ``TreeManager.purge_user_messages``.
+
+    Args:
+        user_id: ID of the user whose messages are purged.
+        purge_initial_prompts: Forwarded unchanged to ``TreeManager.purge_user_messages``.
+        min_date: Optional lower date bound, passed through ``unaware_to_utc`` first.
+        max_date: Optional upper date bound, passed through ``unaware_to_utc`` first.
+        preview: When true (the default) the transaction runs with ``CommitMode.ROLLBACK``,
+            otherwise with ``CommitMode.COMMIT``.
+        api_client: Trusted API client resolved by the dependency.
+
+    Returns:
+        The stats snapshots taken before and after the purge, the ``preview`` flag and
+        the elapsed time reported by ``ScopeTimer``.
+    """
+    assert api_client.trusted
+
+    # NOTE(review): both bounds may be None here; unaware_to_utc is assumed to
+    # pass None through unchanged -- confirm against oasst_shared.utils.
+    min_date = unaware_to_utc(min_date)
+    max_date = unaware_to_utc(max_date)
+
+    @managed_tx_function(CommitMode.ROLLBACK if preview else CommitMode.COMMIT)
+    def purge_user_messages_tx(session: deps.Session) -> tuple[User, SystemStats, SystemStats]:
+        pr = PromptRepository(session, api_client)
+
+        stats_before = pr.get_stats()
+
+        user = pr.user_repository.get_user(user_id)
+
+        tm = TreeManager(session, pr)
+        tm.purge_user_messages(
+            user_id, purge_initial_prompts=purge_initial_prompts, min_date=min_date, max_date=max_date
+        )
+
+        # Detach the user from the session; presumably so its attributes stay
+        # readable for the log lines below once the transaction has ended.
+        session.expunge(user)
+        return user, stats_before, pr.get_stats()
+
+    timer = ScopeTimer()
+    user, before, after = purge_user_messages_tx()
+    timer.stop()
+
+    # A preview is logged at info level; a real purge is logged as a warning.
+    if preview:
+        logger.info(
+            f"PURGE USER MESSAGES PREVIEW: '{user.display_name}' (id: {str(user_id)}; username: '{user.username}'; auth-method: '{user.auth_method}')"
+        )
+    else:
+        logger.warning(
+            f"PURGE USER MESSAGES: '{user.display_name}' (id: {str(user_id)}; username: '{user.username}'; auth-method: '{user.auth_method}')"
+        )
+
+    logger.info(f"{before=}; {after=}")
+    return PurgeResultModel(before=before, after=after, preview=preview, duration=timer.elapsed)
diff --git a/backend/oasst_backend/api/v1/frontend_messages.py b/backend/oasst_backend/api/v1/frontend_messages.py
index f149bebb..dffb2824 100644
--- a/backend/oasst_backend/api/v1/frontend_messages.py
+++ b/backend/oasst_backend/api/v1/frontend_messages.py
@@ -45,7 +45,7 @@ def get_tree_by_frontend_id(
"""
pr = PromptRepository(db, api_client)
message = pr.fetch_message_by_frontend_message_id(message_id)
- tree = pr.fetch_message_tree(message.message_tree_id)
+ tree = pr.fetch_message_tree(message.message_tree_id, reviewed=False)
return utils.prepare_tree(tree, message.message_tree_id)
diff --git a/backend/oasst_backend/api/v1/leaderboards.py b/backend/oasst_backend/api/v1/leaderboards.py
index 213855a1..27366475 100644
--- a/backend/oasst_backend/api/v1/leaderboards.py
+++ b/backend/oasst_backend/api/v1/leaderboards.py
@@ -6,6 +6,7 @@ from oasst_backend.models import ApiClient
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
from oasst_shared.schemas.protocol import LeaderboardStats
from sqlmodel import Session
+from starlette.status import HTTP_204_NO_CONTENT
router = APIRouter()
@@ -19,3 +20,22 @@ def get_leaderboard(
) -> LeaderboardStats:
usr = UserStatsRepository(db)
return usr.get_leaderboard(time_frame, limit=max_count)
+
+
+@router.post("/update/{time_frame}", response_model=None, status_code=HTTP_204_NO_CONTENT)
+def update_leaderboard_time_frame(
+    time_frame: UserStatsTimeFrame,
+    api_client: ApiClient = Depends(deps.get_trusted_api_client),
+    db: Session = Depends(deps.get_db),
+) -> None:
+    """Run ``UserStatsRepository.update_stats`` for one time frame.
+
+    Only trusted API clients pass the dependency. The route is declared as
+    204 No Content with no response model, so nothing is returned.
+    """
+    usr = UserStatsRepository(db)
+    usr.update_stats(time_frame=time_frame)
+
+
+@router.post("/update", response_model=None, status_code=HTTP_204_NO_CONTENT)
+def update_leaderboards_all(
+    api_client: ApiClient = Depends(deps.get_trusted_api_client),
+    db: Session = Depends(deps.get_db),
+) -> None:
+    """Run ``UserStatsRepository.update_all_time_frames``.
+
+    Only trusted API clients pass the dependency. The route is declared as
+    204 No Content with no response model, so nothing is returned.
+    """
+    usr = UserStatsRepository(db)
+    usr.update_all_time_frames()
diff --git a/backend/oasst_backend/api/v1/messages.py b/backend/oasst_backend/api/v1/messages.py
index 6229e20c..409240cb 100644
--- a/backend/oasst_backend/api/v1/messages.py
+++ b/backend/oasst_backend/api/v1/messages.py
@@ -7,6 +7,7 @@ from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
+from oasst_shared.utils import unaware_to_utc
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT
@@ -29,6 +30,9 @@ def query_messages(
"""
Query messages.
"""
+ start_date = unaware_to_utc(start_date)
+ end_date = unaware_to_utc(end_date)
+
pr = PromptRepository(db, api_client)
messages = pr.query_messages(
username=username,
@@ -78,7 +82,7 @@ def get_tree(
"""
pr = PromptRepository(db, api_client)
message = pr.fetch_message(message_id)
- tree = pr.fetch_message_tree(message.message_tree_id)
+ tree = pr.fetch_message_tree(message.message_tree_id, reviewed=False)
return utils.prepare_tree(tree, message.message_tree_id)
diff --git a/backend/oasst_backend/api/v1/stats.py b/backend/oasst_backend/api/v1/stats.py
index 1aaffb1b..f6bbbdf3 100644
--- a/backend/oasst_backend/api/v1/stats.py
+++ b/backend/oasst_backend/api/v1/stats.py
@@ -2,6 +2,7 @@ from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
+from oasst_backend.tree_manager import TreeManager, TreeManagerStats, TreeMessageCountStats
from oasst_shared.schemas import protocol
from sqlmodel import Session
@@ -15,3 +16,34 @@ def get_message_stats(
):
pr = PromptRepository(db, api_client)
return pr.get_stats()
+
+
+@router.get("/tree_manager/state_counts", response_model=dict[str, int])
+def get_tree_manager__state_counts(
+    db: Session = Depends(deps.get_db),
+    api_client: ApiClient = Depends(deps.get_trusted_api_client),
+):
+    """Return ``TreeManager.tree_counts_by_state()`` for a trusted API client."""
+    tree_manager = TreeManager(db, PromptRepository(db, api_client))
+    return tree_manager.tree_counts_by_state()
+
+
+@router.get("/tree_manager/message_counts", response_model=list[TreeMessageCountStats])
+def get_tree_manager__message_counts(
+    only_active: bool = True,
+    db: Session = Depends(deps.get_db),
+    api_client: ApiClient = Depends(deps.get_trusted_api_client),
+):
+    """Return ``TreeManager.tree_message_count_stats()``; ``only_active`` is forwarded as given."""
+    tree_manager = TreeManager(db, PromptRepository(db, api_client))
+    return tree_manager.tree_message_count_stats(only_active=only_active)
+
+
+@router.get("/tree_manager", response_model=TreeManagerStats)
+def get_tree_manager__stats(
+    db: Session = Depends(deps.get_db),
+    api_client: ApiClient = Depends(deps.get_trusted_api_client),
+):
+    """Return ``TreeManager.stats()`` for a trusted API client."""
+    tree_manager = TreeManager(db, PromptRepository(db, api_client))
+    return tree_manager.stats()
diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py
index c65500fb..9e9118da 100644
--- a/backend/oasst_backend/api/v1/tasks.py
+++ b/backend/oasst_backend/api/v1/tasks.py
@@ -36,6 +36,8 @@ def request_task(
try:
pr = PromptRepository(db, api_client, client_user=request.user)
+ pr.ensure_user_is_enabled()
+
tm = TreeManager(db, pr)
task, message_tree_id, parent_message_id = tm.next_task(request.type)
pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective)
diff --git a/backend/oasst_backend/models/journal.py b/backend/oasst_backend/models/journal.py
index b5000add..46a72bdd 100644
--- a/backend/oasst_backend/models/journal.py
+++ b/backend/oasst_backend/models/journal.py
@@ -50,6 +50,6 @@ class JournalIntegration(SQLModel, table=True):
)
description: str = Field(max_length=512, primary_key=True)
last_journal_id: Optional[UUID] = Field(foreign_key="journal.id", nullable=True)
- last_run: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
+ last_run: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
last_error: Optional[str] = Field(nullable=True)
next_run: Optional[datetime] = Field(nullable=True)
diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py
index 7c8b9f13..b03c8534 100644
--- a/backend/oasst_backend/models/message.py
+++ b/backend/oasst_backend/models/message.py
@@ -30,7 +30,9 @@ class Message(SQLModel, table=True):
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
frontend_message_id: str = Field(max_length=200, nullable=False)
created_date: Optional[datetime] = Field(
- sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True)
+ sa_column=sa.Column(
+ sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp(), index=True
+ )
)
payload_type: str = Field(nullable=False, max_length=200)
payload: Optional[PayloadContainer] = Field(
diff --git a/backend/oasst_backend/models/message_embedding.py b/backend/oasst_backend/models/message_embedding.py
index 74da5004..e75f5398 100644
--- a/backend/oasst_backend/models/message_embedding.py
+++ b/backend/oasst_backend/models/message_embedding.py
@@ -17,5 +17,5 @@ class MessageEmbedding(SQLModel, table=True):
# In the case that the Message Embedding is created afterwards
created_date: Optional[datetime] = Field(
- sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
+ sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
)
diff --git a/backend/oasst_backend/models/message_reaction.py b/backend/oasst_backend/models/message_reaction.py
index 74e21a61..4c50143e 100644
--- a/backend/oasst_backend/models/message_reaction.py
+++ b/backend/oasst_backend/models/message_reaction.py
@@ -19,7 +19,9 @@ class MessageReaction(SQLModel, table=True):
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False, primary_key=True)
)
created_date: Optional[datetime] = Field(
- sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True)
+ sa_column=sa.Column(
+ sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp(), index=True
+ )
)
payload_type: str = Field(nullable=False, max_length=200)
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
diff --git a/backend/oasst_backend/models/message_toxicity.py b/backend/oasst_backend/models/message_toxicity.py
index 8a78e2dc..f8eb787b 100644
--- a/backend/oasst_backend/models/message_toxicity.py
+++ b/backend/oasst_backend/models/message_toxicity.py
@@ -20,5 +20,5 @@ class MessageToxicity(SQLModel, table=True):
# In the case that the Message Embedding is created afterwards
created_date: Optional[datetime] = Field(
- sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
+ sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
)
diff --git a/backend/oasst_backend/models/task.py b/backend/oasst_backend/models/task.py
index 5d0d7e73..a59f689e 100644
--- a/backend/oasst_backend/models/task.py
+++ b/backend/oasst_backend/models/task.py
@@ -20,9 +20,9 @@ class Task(SQLModel, table=True):
),
)
created_date: Optional[datetime] = Field(
- sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
+ sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp()),
)
- expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
+ expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True)
payload_type: str = Field(nullable=False, max_length=200)
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
diff --git a/backend/oasst_backend/models/text_labels.py b/backend/oasst_backend/models/text_labels.py
index 34831d6b..1d238ef2 100644
--- a/backend/oasst_backend/models/text_labels.py
+++ b/backend/oasst_backend/models/text_labels.py
@@ -17,7 +17,9 @@ class TextLabels(SQLModel, table=True):
)
user_id: UUID = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False))
created_date: Optional[datetime] = Field(
- sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp(), index=True),
+ sa_column=sa.Column(
+ sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp(), index=True
+ ),
)
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
text: str = Field(nullable=False, max_length=2**16)
diff --git a/backend/oasst_backend/models/user.py b/backend/oasst_backend/models/user.py
index 7ce914c6..0fb36c22 100644
--- a/backend/oasst_backend/models/user.py
+++ b/backend/oasst_backend/models/user.py
@@ -21,7 +21,7 @@ class User(SQLModel, table=True):
auth_method: str = Field(nullable=False, max_length=128, default="local")
display_name: str = Field(nullable=False, max_length=256)
created_date: Optional[datetime] = Field(
- sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
+ sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
)
api_client_id: UUID = Field(foreign_key="api_client.id")
enabled: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.true()))
diff --git a/backend/oasst_backend/models/user_stats.py b/backend/oasst_backend/models/user_stats.py
index 5ba9dcdb..e8f5b450 100644
--- a/backend/oasst_backend/models/user_stats.py
+++ b/backend/oasst_backend/models/user_stats.py
@@ -26,11 +26,11 @@ class UserStats(SQLModel, table=True):
user_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), primary_key=True)
)
- base_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
+ base_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
leader_score: int = 0
modified_date: Optional[datetime] = Field(
- sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
+ sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
)
rank: int = Field(nullable=True)
diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py
index 7f51ea19..35e1ece4 100644
--- a/backend/oasst_backend/prompt_repository.py
+++ b/backend/oasst_backend/prompt_repository.py
@@ -28,8 +28,7 @@ from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import SystemStats
-from sqlalchemy import update
-from sqlmodel import Session, func
+from sqlmodel import Session, func, not_, text, update
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
@@ -53,6 +52,13 @@ class PromptRepository:
)
self.journal = JournalWriter(db, api_client, self.user)
+    def ensure_user_is_enabled(self):
+        """Raise ``OasstError`` unless this repository has a user that is enabled and not deleted.
+
+        Raises:
+            OasstError: ``USER_NOT_SPECIFIED`` when ``self.user`` or ``self.user_id`` is None;
+                ``USER_DISABLED`` when the user is flagged deleted or is not enabled.
+        """
+        user_missing = self.user is None or self.user_id is None
+        if user_missing:
+            raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
+
+        account_usable = not self.user.deleted and self.user.enabled
+        if not account_usable:
+            raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED)
+
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
validate_frontend_message_id(frontend_message_id)
message: Message = (
@@ -146,6 +152,8 @@ class PromptRepository:
review_result: bool = False,
check_tree_state: bool = True,
) -> Message:
+ self.ensure_user_is_enabled()
+
validate_frontend_message_id(frontend_message_id)
validate_frontend_message_id(user_frontend_message_id)
@@ -354,8 +362,7 @@ class PromptRepository:
@managed_tx_method(CommitMode.FLUSH)
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
- if self.user_id is None:
- raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
+ self.ensure_user_is_enabled()
container = PayloadContainer(payload=payload)
reaction = MessageReaction(
@@ -499,10 +506,14 @@ class PromptRepository:
messages = self.db.query(Message).filter(Message.parent_id.is_(None)).order_by(func.random()).limit(size).all()
return messages
- def fetch_message_tree(self, message_tree_id: UUID, reviewed: bool = True):
+    def fetch_message_tree(
+        self, message_tree_id: UUID, reviewed: bool = True, include_deleted: bool = False
+    ) -> list[Message]:
+        """Return the messages that belong to the given message tree.
+
+        Args:
+            message_tree_id: Tree whose messages are queried.
+            reviewed: When true, keep only messages whose ``review_result`` is truthy.
+            include_deleted: When false (the default), drop messages flagged ``deleted``.
+        """
         qry = self.db.query(Message).filter(Message.message_tree_id == message_tree_id)
         if reviewed:
             qry = qry.filter(Message.review_result)
+        if not include_deleted:
+            qry = qry.filter(not_(Message.deleted))
         return qry.all()
def fetch_multiple_random_replies(self, max_size: int = 5, message_role: str = None):
@@ -702,6 +713,21 @@ class PromptRepository:
return messages.all()
+    def update_children_counts(self, message_tree_id: UUID):
+        """Recompute ``children_count`` for every message of one message tree.
+
+        Each message's count is set to the number of its direct children minus
+        those flagged ``deleted``. The statement is executed on ``self.db``; no
+        commit is issued in this method.
+        """
+        # LEFT JOIN keeps messages without children; for them count(c.id) is 0
+        # and the COALESCE turns the NULL sum into 0.
+        sql_update_children_count = """
+UPDATE message SET children_count = cc.children_count
+FROM (
+    SELECT m.id, count(c.id) - COALESCE(SUM(c.deleted::int), 0) AS children_count
+    FROM message m
+    LEFT JOIN message c ON m.id = c.parent_id
+    WHERE m.message_tree_id = :message_tree_id
+    GROUP BY m.id
+) AS cc
+WHERE message.id = cc.id;
+"""
+        # The tree id is passed as a bound parameter, not interpolated into the SQL text.
+        r = self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id})
+        logger.debug(f"update_children_count({message_tree_id=}): {r.rowcount} rows.")
+
@managed_tx_method(CommitMode.COMMIT)
def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True):
"""
diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py
index 225b0146..f4338048 100644
--- a/backend/oasst_backend/tree_manager.py
+++ b/backend/oasst_backend/tree_manager.py
@@ -1,4 +1,5 @@
import random
+from datetime import datetime
from enum import Enum
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple
@@ -9,14 +10,14 @@ import pydantic
from loguru import logger
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
from oasst_backend.config import TreeManagerConfiguration, settings
-from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, message_tree_state
+from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, User, message_tree_state
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
+from oasst_backend.utils.ranking import ranked_pairs
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
-from sqlalchemy.sql import text
-from sqlmodel import Session, func, not_
+from sqlmodel import Session, func, not_, text, update
class TaskType(Enum):
@@ -68,6 +69,25 @@ class IncompleteRankingsRow(pydantic.BaseModel):
orm_mode = True
+class TreeMessageCountStats(pydantic.BaseModel):
+    """Per-tree message statistics exposed by the tree-manager stats endpoints."""
+
+    message_tree_id: UUID
+    state: str
+    depth: int
+    # NOTE(review): presumably the earliest and latest message timestamps of the
+    # tree -- confirm against the query that fills this model.
+    oldest: datetime
+    youngest: datetime
+    count: int
+    goal_tree_size: int
+
+    @property
+    def completed(self) -> float:
+        """Ratio ``count / goal_tree_size``; true division, so the result is a float."""
+        return self.count / self.goal_tree_size
+
+
+class TreeManagerStats(pydantic.BaseModel):
+    # Aggregate tree-manager statistics (response model of the /tree_manager stats route).
+    # Integer count per state name.
+    state_counts: dict[str, int]
+    # One TreeMessageCountStats entry per reported tree.
+    message_counts: list[TreeMessageCountStats]
+
+
class TreeManager:
_all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel))
@@ -130,7 +150,7 @@ class TreeManager:
def _determine_task_availability_internal(
self,
num_active_trees: int,
- extensible_parents: list[ExtendibleParentRow],
+ extendible_parents: list[ExtendibleParentRow],
prompts_need_review: list[Message],
replies_need_review: list[Message],
incomplete_rankings: list[IncompleteRankingsRow],
@@ -141,17 +161,17 @@ class TreeManager:
task_count_by_type[protocol_schema.TaskRequestType.initial_prompt] = num_missing_prompts
task_count_by_type[protocol_schema.TaskRequestType.prompter_reply] = len(
- list(filter(lambda x: x.parent_role == "assistant", extensible_parents))
+ list(filter(lambda x: x.parent_role == "assistant", extendible_parents))
)
task_count_by_type[protocol_schema.TaskRequestType.assistant_reply] = len(
- list(filter(lambda x: x.parent_role == "prompter", extensible_parents))
+ list(filter(lambda x: x.parent_role == "prompter", extendible_parents))
)
task_count_by_type[protocol_schema.TaskRequestType.label_initial_prompt] = len(prompts_need_review)
task_count_by_type[protocol_schema.TaskRequestType.label_assistant_reply] = len(
list(filter(lambda m: m.role == "assistant", replies_need_review))
)
- task_count_by_type[protocol_schema.TaskRequestType.prompter_reply] = len(
+ task_count_by_type[protocol_schema.TaskRequestType.label_prompter_reply] = len(
list(filter(lambda m: m.role == "prompter", replies_need_review))
)
@@ -171,15 +191,17 @@ class TreeManager:
return task_count_by_type
def determine_task_availability(self) -> dict[protocol_schema.TaskRequestType, int]:
+ self.pr.ensure_user_is_enabled()
+
num_active_trees = self.query_num_active_trees()
- extensible_parents = self.query_extendible_parents()
+ extendible_parents = self.query_extendible_parents()
prompts_need_review = self.query_prompts_need_review()
replies_need_review = self.query_replies_need_review()
incomplete_rankings = self.query_incomplete_rankings()
return self._determine_task_availability_internal(
num_active_trees=num_active_trees,
- extensible_parents=extensible_parents,
+ extendible_parents=extendible_parents,
prompts_need_review=prompts_need_review,
replies_need_review=replies_need_review,
incomplete_rankings=incomplete_rankings,
@@ -191,10 +213,12 @@ class TreeManager:
logger.debug("TreeManager.next_task()")
+ self.pr.ensure_user_is_enabled()
+
num_active_trees = self.query_num_active_trees()
prompts_need_review = self.query_prompts_need_review()
replies_need_review = self.query_replies_need_review()
- extensible_parents = self.query_extendible_parents()
+ extendible_parents = self.query_extendible_parents()
incomplete_rankings = self.query_incomplete_rankings()
if not self.cfg.rank_prompter_replies:
@@ -224,7 +248,7 @@ class TreeManager:
else:
task_count_by_type = self._determine_task_availability_internal(
num_active_trees=num_active_trees,
- extensible_parents=extensible_parents,
+ extendible_parents=extendible_parents,
prompts_need_review=prompts_need_review,
replies_need_review=replies_need_review,
incomplete_rankings=incomplete_rankings,
@@ -266,7 +290,7 @@ class TreeManager:
ranking_parent_id = random.choice(incomplete_rankings).parent_id
messages = self.pr.fetch_message_conversation(ranking_parent_id)
- assert len(messages) > 1 and messages[-1].id == ranking_parent_id
+ assert len(messages) > 0 and messages[-1].id == ranking_parent_id
ranking_parent = messages[-1]
assert not ranking_parent.deleted and ranking_parent.review_result
conversation = prepare_conversation(messages)
@@ -356,12 +380,12 @@ class TreeManager:
case TaskType.REPLY:
# select a tree with missing replies
if task_role == TaskRole.PROMPTER:
- extensible_parents = list(filter(lambda x: x.parent_role == "assistant", extensible_parents))
+ extendible_parents = list(filter(lambda x: x.parent_role == "assistant", extendible_parents))
elif task_role == TaskRole.ASSISTANT:
- extensible_parents = list(filter(lambda x: x.parent_role == "prompter", extensible_parents))
+ extendible_parents = list(filter(lambda x: x.parent_role == "prompter", extendible_parents))
- if len(extensible_parents) > 0:
- random_parent = random.choice(extensible_parents)
+ if len(extendible_parents) > 0:
+ random_parent = random.choice(extendible_parents)
# fetch random conversation to extend
logger.debug(f"selected {random_parent=}")
@@ -424,6 +448,7 @@ class TreeManager:
@async_managed_tx_method(CommitMode.COMMIT)
async def handle_interaction(self, interaction: protocol_schema.AnyInteraction) -> protocol_schema.Task:
pr = self.pr
+ pr.ensure_user_is_enabled()
match type(interaction):
case protocol_schema.TextReplyToMessage:
logger.info(
@@ -488,7 +513,8 @@ class TreeManager:
_, task = pr.store_ranking(interaction)
- self.check_condition_for_scoring_state(task.message_tree_id)
+ ok, rankings_by_message = self.check_condition_for_scoring_state(task.message_tree_id)
+ self.update_message_ranks(task.message_tree_id, rankings_by_message)
case protocol_schema.TextLabels:
logger.info(
@@ -551,7 +577,6 @@ class TreeManager:
mts = self.pr.fetch_tree_state(message_tree_id)
self._enter_state(mts, message_tree_state.State.ABORTED_LOW_GRADE)
- @managed_tx_method(CommitMode.COMMIT)
def check_condition_for_growing_state(self, message_tree_id: UUID) -> bool:
logger.debug(f"check_condition_for_growing_state({message_tree_id=})")
@@ -569,7 +594,6 @@ class TreeManager:
self._enter_state(mts, message_tree_state.State.GROWING)
return True
- @managed_tx_method(CommitMode.COMMIT)
def check_condition_for_ranking_state(self, message_tree_id: UUID) -> bool:
logger.debug(f"check_condition_for_ranking_state({message_tree_id=})")
@@ -587,22 +611,54 @@ class TreeManager:
self._enter_state(mts, message_tree_state.State.RANKING)
return True
- def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool:
+ def check_condition_for_scoring_state(
+ self, message_tree_id: UUID
+ ) -> Tuple[bool, dict[UUID, list[MessageReaction]]]:
logger.debug(f"check_condition_for_scoring_state({message_tree_id=})")
- mts: MessageTreeState
- mts = self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one()
+
+ mts = self.pr.fetch_tree_state(message_tree_id)
if not mts.active or mts.state != message_tree_state.State.RANKING:
logger.debug(f"False {mts.active=}, {mts.state=}")
- return False
+ return False, None
ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant"
rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter)
for parent_msg_id, ranking in rankings_by_message.items():
if len(ranking) < self.cfg.num_required_rankings:
logger.debug(f"False {parent_msg_id=} {len(ranking)=}")
- return False
+ return False, None
self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING)
+ return True, rankings_by_message
+
+    def update_message_ranks(
+        self, message_tree_id: UUID, rankings_by_message: dict[UUID, list[MessageReaction]] | None
+    ) -> bool:
+        """Compute the consensus rank of every ranked message of a tree and store it in `Message.rank`.
+
+        :param message_tree_id: id of the message tree to score
+        :param rankings_by_message: ranking reactions grouped by parent message id, i.e. the second
+            element returned by `check_condition_for_scoring_state()`; that method returns `None`
+            whenever the tree is not ready for scoring
+        :return: True if the ranks were stored and the tree entered READY_FOR_EXPORT, otherwise False
+        """
+        if rankings_by_message is None:
+            # Nothing to score. Without this guard a tree in SCORING_FAILED state would reach
+            # `None.values()` below and be logged as a scoring failure on every call.
+            logger.debug(f"update_message_ranks({message_tree_id=}): no rankings available")
+            return False
+
+        mts = self.pr.fetch_tree_state(message_tree_id)
+        # check state, allow retry if in SCORING_FAILED state
+        if mts.state not in (message_tree_state.State.READY_FOR_SCORING, message_tree_state.State.SCORING_FAILED):
+            logger.debug(f"False {mts.active=}, {mts.state=}")
+            return False
+
+        try:
+            for rankings in rankings_by_message.values():
+                # one ballot (ordered list of message ids) per ranking reaction
+                sorted_messages = [msg_reaction.payload.payload.ranked_message_ids for msg_reaction in rankings]
+                logger.debug(f"SORTED MESSAGE {sorted_messages}")
+                consensus = ranked_pairs(sorted_messages)
+                logger.debug(f"CONSENSUS: {consensus}\n\n")
+                for rank, message_id in enumerate(consensus):
+                    # set rank for each message_id for Message rows
+                    msg = self.pr.fetch_message(message_id=message_id, fail_if_missing=True)
+                    msg.rank = rank
+                    self.db.add(msg)
+
+        except Exception:
+            # any failure (missing message, malformed ballots, ...) marks the tree so scoring can be retried
+            logger.exception(f"update_message_ranks({message_tree_id=}) failed")
+            self._enter_state(mts, message_tree_state.State.SCORING_FAILED)
+            return False
+
+        self._enter_state(mts, message_tree_state.State.READY_FOR_EXPORT)
return True
def _calculate_acceptance(self, labels: list[TextLabels]):
@@ -618,7 +674,7 @@ class TreeManager:
qry = (
self.db.query(Message)
.select_from(MessageTreeState)
- .outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
+ .join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
.filter(
MessageTreeState.active,
MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW,
@@ -643,7 +699,7 @@ class TreeManager:
qry = (
self.db.query(Message)
.select_from(MessageTreeState)
- .outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
+ .join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
.filter(
MessageTreeState.active,
MessageTreeState.state == message_tree_state.State.GROWING,
@@ -664,7 +720,7 @@ class TreeManager:
SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count,
COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings
FROM message_tree_state mts
- LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
+ INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE mts.active -- only consider active trees
AND mts.state = :ranking_state -- message tree must be in ranking state
AND m.review_result -- must be reviewed
@@ -690,15 +746,15 @@ HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
-- find all extendible parent nodes
SELECT m.id as parent_id, m.role as parent_role, m.depth, m.message_tree_id, COUNT(c.id) active_children_count
FROM message_tree_state mts
- LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
+ INNER JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree
LEFT JOIN message c ON m.id = c.parent_id -- child nodes
WHERE mts.active -- only consider active trees
AND mts.state = :growing_state -- message tree must be growing
AND NOT m.deleted -- ignore deleted messages as parents
AND m.depth < mts.max_depth -- ignore leaf nodes as parents
AND m.review_result -- parent node must have positive review
- AND NOT c.deleted -- don't count deleted children
- AND (c.review_result OR c.review_count < :num_reviews_reply) -- don't count children with negative review but count elements under review
+ AND NOT coalesce(c.deleted, FALSE) -- don't count deleted children
+ AND (c.review_result OR coalesce(c.review_count, 0) < :num_reviews_reply) -- don't count children with negative review but count elements under review
GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count
HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
"""
@@ -708,7 +764,10 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
r = self.db.execute(
text(self._sql_find_extendible_parents),
- {"growing_state": message_tree_state.State.GROWING, "num_reviews_reply": self.cfg.num_reviews_reply},
+ {
+ "growing_state": message_tree_state.State.GROWING,
+ "num_reviews_reply": self.cfg.num_reviews_reply,
+ },
)
return [ExtendibleParentRow.from_orm(x) for x in r.all()]
@@ -717,8 +776,8 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
SELECT m.message_tree_id, mts.goal_tree_size, COUNT(m.id) AS tree_size
FROM (
SELECT DISTINCT message_tree_id FROM ({_sql_find_extendible_parents}) extendible_parents
- ) trees LEFT JOIN message_tree_state mts ON trees.message_tree_id = mts.message_tree_id
- LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
+ ) trees INNER JOIN message_tree_state mts ON trees.message_tree_id = mts.message_tree_id
+ INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE NOT m.deleted
AND (
m.parent_id IS NOT NULL AND (m.review_result OR m.review_count < :num_reviews_reply) -- children
@@ -766,7 +825,7 @@ HAVING COUNT(m.id) < mts.goal_tree_size
"""Find all initial prompt messages that have no associated message tree state"""
qry_missing_tree_states = (
self.db.query(Message.id)
- .join(MessageTreeState, isouter=True)
+ .outerjoin(MessageTreeState, Message.message_tree_id == MessageTreeState.message_tree_id)
.filter(
Message.parent_id.is_(None),
Message.message_tree_id == Message.id,
@@ -783,7 +842,7 @@ SELECT p.parent_id, mr.* FROM
-- find parents with > 1 children
SELECT m.parent_id, m.message_tree_id, COUNT(m.id) children_count
FROM message_tree_state mts
- LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id
+ INNER JOIN message m ON mts.message_tree_id = m.message_tree_id
WHERE m.review_result -- must be reviewed
AND NOT m.deleted -- not deleted
AND m.parent_id IS NOT NULL -- ignore initial prompts
@@ -792,8 +851,8 @@ SELECT p.parent_id, mr.* FROM
GROUP BY m.parent_id, m.message_tree_id
HAVING COUNT(m.id) > 1
) as p
-LEFT JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload')
-LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
+INNER JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload')
+INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload'
"""
def query_tree_ranking_results(
@@ -832,7 +891,7 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
state = message_tree_state.State.INITIAL_PROMPT_REVIEW
if tree_size > 1:
state = message_tree_state.State.GROWING
- logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=})")
+ logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=:s})")
self._insert_default_state(id, state=state)
def query_num_active_trees(self) -> int:
@@ -885,6 +944,194 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
active=True,
)
+    def tree_counts_by_state(self) -> dict[str, int]:
+        """Return the number of `MessageTreeState` rows per state, keyed by the state value."""
+        tree_count = func.count(MessageTreeState.message_tree_id).label("count")
+        rows = self.db.query(MessageTreeState.state, tree_count).group_by(MessageTreeState.state)
+
+        counts = {}
+        for row in rows:
+            counts[row["state"]] = row["count"]
+        return counts
+
+    def tree_message_count_stats(self, only_active: bool = True) -> list[TreeMessageCountStats]:
+        """Collect per-tree message statistics over all non-deleted messages.
+
+        :param only_active: when True, only trees whose `MessageTreeState.active` flag is set are included
+        :return: one `TreeMessageCountStats` per message tree
+        """
+        columns = (
+            MessageTreeState.message_tree_id,
+            func.max(Message.depth).label("depth"),
+            func.min(Message.created_date).label("oldest"),
+            func.max(Message.created_date).label("youngest"),
+            func.count(Message.id).label("count"),
+            MessageTreeState.goal_tree_size,
+            MessageTreeState.state,
+        )
+        stats_query = (
+            self.db.query(*columns)
+            .select_from(MessageTreeState)
+            .join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
+            .filter(not_(Message.deleted))
+        )
+        if only_active:
+            stats_query = stats_query.filter(MessageTreeState.active)
+        stats_query = stats_query.group_by(MessageTreeState.message_tree_id)
+
+        return [TreeMessageCountStats(**row) for row in stats_query]
+
+    def stats(self) -> TreeManagerStats:
+        """Bundle the per-state tree counts and the message statistics of all active trees."""
+        state_counts = self.tree_counts_by_state()
+        message_counts = self.tree_message_count_stats(only_active=True)
+        return TreeManagerStats(state_counts=state_counts, message_counts=message_counts)
+
+    def get_user_messages_by_tree(
+        self,
+        user_id: UUID,
+        min_date: datetime = None,
+        max_date: datetime = None,
+    ) -> Tuple[dict[UUID, list[Message]], list[Message]]:
+        """Split the messages written by `user_id` into replies (grouped by tree) and initial prompts.
+
+        :param user_id: author whose messages are collected
+        :param min_date: if given, only messages with `created_date >= min_date` are considered
+        :param max_date: if given, only messages with `created_date <= max_date` are considered
+        :return: `(replies_by_tree, prompts)`; initial prompts are not part of `replies_by_tree`
+        """
+        user_messages = self.db.query(Message).filter(Message.user_id == user_id)
+        if min_date:
+            user_messages = user_messages.filter(Message.created_date >= min_date)
+        if max_date:
+            user_messages = user_messages.filter(Message.created_date <= max_date)
+
+        prompts: list[Message] = []
+        replies_by_tree: dict[UUID, list[Message]] = {}
+
+        for message in user_messages:
+            # an initial prompt is the message whose id equals the id of its tree
+            if message.message_tree_id == message.id:
+                prompts.append(message)
+                continue
+            replies_by_tree.setdefault(message.message_tree_id, []).append(message)
+
+        return replies_by_tree, prompts
+
+    def _purge_message_internal(self, message_id: UUID) -> None:
+        """Delete a single message together with the rows that reference it.
+
+        This internal function does not take care of descendants, children_count in parent etc.
+
+        :param message_id: id of the message to delete
+        """
+
+        # The journal statement used to read `DELETE FROM journal j USING message m WHERE ...`
+        # without any join condition on `m`, i.e. an unintended cross join with the whole
+        # message table. The filter on j.message_id alone is what is meant.
+        sql_purge_message = """
+DELETE FROM journal j WHERE j.message_id = :message_id;
+DELETE FROM message_embedding e WHERE e.message_id = :message_id;
+DELETE FROM message_toxicity t WHERE t.message_id = :message_id;
+DELETE FROM text_labels l WHERE l.message_id = :message_id;
+-- delete all ranking results that contain message
+DELETE FROM message_reaction r WHERE r.payload_type = 'RankingReactionPayload' AND r.task_id IN (
+    SELECT t.id FROM message m
+        JOIN task t ON m.parent_id = t.parent_message_id
+    WHERE m.id = :message_id);
+-- delete task which inserted message
+DELETE FROM task t using message m WHERE t.id = m.task_id AND m.id = :message_id;
+DELETE FROM task t WHERE t.parent_message_id = :message_id;
+DELETE FROM message WHERE id = :message_id;
+"""
+        r = self.db.execute(text(sql_purge_message), {"message_id": message_id})
+        # NOTE(review): for a multi-statement batch rowcount presumably reflects only the
+        # last statement - confirm before relying on the logged number.
+        logger.debug(f"purge_message({message_id=}): {r.rowcount} rows.")
+
+    def purge_message_tree(self, message_tree_id: UUID) -> None:
+        """Delete a whole message tree: its messages, tasks, reactions, labels, journal
+        entries, embeddings, toxicity rows and the tree state row.
+
+        :param message_tree_id: id of the message tree to delete
+        """
+        params = {"message_tree_id": message_tree_id}
+        statement = text(
+            """
+DELETE FROM journal j USING message m WHERE j.message_id = m.Id AND m.message_tree_id = :message_tree_id;
+DELETE FROM message_embedding e USING message m WHERE e.message_id = m.Id AND m.message_tree_id = :message_tree_id;
+DELETE FROM message_toxicity t USING message m WHERE t.message_id = m.Id AND m.message_tree_id = :message_tree_id;
+DELETE FROM text_labels l USING message m WHERE l.message_id = m.Id AND m.message_tree_id = :message_tree_id;
+DELETE FROM message_reaction r USING task t WHERE r.task_id = t.id AND t.message_tree_id = :message_tree_id;
+DELETE FROM task t WHERE t.message_tree_id = :message_tree_id;
+DELETE FROM message_tree_state WHERE message_tree_id = :message_tree_id;
+DELETE FROM message WHERE message_tree_id = :message_tree_id;
+"""
+        )
+        result = self.db.execute(statement, params)
+        logger.debug(f"purge_message_tree({message_tree_id=}) {result.rowcount} rows.")
+
+    @managed_tx_method(CommitMode.FLUSH)
+    def purge_user_messages(
+        self,
+        user_id: UUID,
+        purge_initial_prompts: bool = True,
+        min_date: datetime = None,
+        max_date: datetime = None,
+    ):
+        """Remove the messages of a user (and everything that descends from them).
+
+        Trees started by the user are deleted completely when `purge_initial_prompts` is set.
+        In every other affected tree the user's replies and all their descendants are purged,
+        the children counts are refreshed and the tree is re-activated and re-evaluated.
+
+        :param user_id: author whose messages are purged
+        :param purge_initial_prompts: also delete whole trees whose initial prompt was written by the user
+        :param min_date: optional lower bound on `Message.created_date`
+        :param max_date: optional upper bound on `Message.created_date`
+        """
+
+        # find all affected message trees
+        replies_by_tree, prompts = self.get_user_messages_by_tree(user_id, min_date, max_date)
+        total_messages = sum(len(x) for x in replies_by_tree.values())
+        logger.debug(f"found: {len(replies_by_tree)} trees; {len(prompts)} prompts; {total_messages} messages;")
+
+        # remove all trees based on initial prompts of the user
+        if purge_initial_prompts:
+            for p in prompts:
+                self.purge_message_tree(p.message_tree_id)
+                # the tree is gone, so there is nothing left to patch in it
+                replies_by_tree.pop(p.message_tree_id, None)
+
+        # patch all affected message trees
+        for tree_id, replies in replies_by_tree.items():
+            bad_parent_ids = set(m.id for m in replies)
+            logger.debug(f"patching tree {tree_id=}, {bad_parent_ids=}")
+
+            tree_messages = self.pr.fetch_message_tree(tree_id, reviewed=False, include_deleted=True)
+            logger.debug(f"{tree_id=}, {len(bad_parent_ids)=}, {len(tree_messages)=}")
+            by_id = {m.id: m for m in tree_messages}
+
+            def ancestor_ids(msg: Message) -> list[UUID]:
+                """Ids of all ancestors of `msg`, nearest parent first."""
+                t = []
+                while msg.parent_id is not None:
+                    msg = by_id[msg.parent_id]
+                    t.append(msg.id)
+                return t
+
+            def is_descendant_of_deleted(m: Message) -> bool:
+                """True if `m` itself or one of its ancestors was written by the purged user."""
+                if m.id in bad_parent_ids:
+                    return True
+                return any(a in bad_parent_ids for a in ancestor_ids(m))
+
+            # start with deepest messages first
+            tree_messages.sort(key=lambda x: x.depth, reverse=True)
+            for m in tree_messages:
+                if is_descendant_of_deleted(m):
+                    logger.debug(f"purging message: {m.id}")
+                    self._purge_message_internal(m.id)
+
+            # update children counts
+            # Use tree_id rather than the loop variable `m`: after the loop `m` is unbound
+            # (or left over from the previous tree) whenever tree_messages is empty.
+            self.pr.update_children_counts(tree_id)
+
+            # reactivate tree
+            logger.info(f"reactivating message tree {tree_id}")
+            mts = self.pr.fetch_tree_state(tree_id)
+            mts.active = True
+            self._enter_state(mts, message_tree_state.State.INITIAL_PROMPT_REVIEW)
+            self.check_condition_for_growing_state(tree_id)
+            self.check_condition_for_ranking_state(tree_id)
+            self.check_condition_for_scoring_state(tree_id)
+
+    @managed_tx_method(CommitMode.FLUSH)
+    def purge_user(self, user_id: UUID, ban: bool = True) -> None:
+        """Purge all messages of a user, delete the user's remaining rows and optionally ban the user.
+
+        :param user_id: id of the user to purge
+        :param ban: when True the user row is marked `deleted=True, enabled=False`
+        """
+        self.purge_user_messages(user_id, purge_initial_prompts=True)
+
+        # delete all remaining rows of the user
+        statement = text(
+            """
+DELETE FROM journal WHERE user_id = :user_id;
+DELETE FROM message_reaction WHERE user_id = :user_id;
+DELETE FROM task WHERE user_id = :user_id;
+DELETE FROM message WHERE user_id = :user_id;
+DELETE FROM user_stats WHERE user_id = :user_id;
+"""
+        )
+        result = self.db.execute(statement, {"user_id": user_id})
+        logger.debug(f"purge_user({user_id=}): {result.rowcount} rows.")
+
+        if not ban:
+            return
+        # ban user
+        self.db.execute(update(User).filter(User.id == user_id).values(deleted=True, enabled=False))
+
if __name__ == "__main__":
from oasst_backend.api.deps import api_auth
@@ -901,6 +1148,10 @@ if __name__ == "__main__":
tm = TreeManager(db, pr, cfg)
tm.ensure_tree_states()
+ tm.purge_user_messages(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"), purge_initial_prompts=False)
+ # tm.purge_user(user_id=UUID("2ef9ad21-0dc5-442d-8750-6f7f1790723f"))
+ # db.commit()
+
# print("query_num_active_trees", tm.query_num_active_trees())
# print("query_incomplete_rankings", tm.query_incomplete_rankings())
# print("query_replies_need_review", tm.query_replies_need_review())
@@ -909,10 +1160,10 @@ if __name__ == "__main__":
# print("query_extendible_parents", tm.query_extendible_parents())
# print("query_tree_size", tm.query_tree_size(message_tree_id=UUID("bdf434cf-4df5-4b74-949c-a5a157bc3292")))
- print(
- "query_reviews_for_message",
- tm.query_reviews_for_message(message_id=UUID("6a444493-0d48-4316-a9f1-7e263f5a2473")),
- )
+ # print(
+ # "query_reviews_for_message",
+ # tm.query_reviews_for_message(message_id=UUID("6a444493-0d48-4316-a9f1-7e263f5a2473")),
+ # )
# print("next_task:", tm.next_task())
diff --git a/backend/oasst_backend/user_stats_repository.py b/backend/oasst_backend/user_stats_repository.py
index bdd0e2e9..8b0c17f0 100644
--- a/backend/oasst_backend/user_stats_repository.py
+++ b/backend/oasst_backend/user_stats_repository.py
@@ -4,6 +4,7 @@ from uuid import UUID
import sqlalchemy as sa
from loguru import logger
+from oasst_backend.config import settings
from oasst_backend.models import Message, MessageReaction, Task, User, UserStats, UserStatsTimeFrame
from oasst_backend.models.db_payload import (
LabelAssistantReplyPayload,
@@ -39,12 +40,16 @@ class UserStatsRepository:
self.session.query(User.id.label("user_id"), User.username, User.auth_method, User.display_name, UserStats)
.join(UserStats, User.id == UserStats.user_id)
.filter(UserStats.time_frame == time_frame.value)
- .order_by(UserStats.leader_score.desc())
+ .order_by(UserStats.rank)
.limit(limit)
)
leaderboard = [_create_user_score(r) for r in self.session.exec(qry)]
- return LeaderboardStats(time_frame=time_frame.value, leaderboard=leaderboard)
+ if len(leaderboard) > 0:
+ last_update = max(x.modified_date for x in leaderboard)
+ else:
+ last_update = utcnow()
+ return LeaderboardStats(time_frame=time_frame.value, leaderboard=leaderboard, last_updated=last_update)
def get_user_stats_all_time_frames(self, user_id: UUID) -> dict[str, UserScore | None]:
qry = (
@@ -291,13 +296,11 @@ WHERE
if __name__ == "__main__":
- from oasst_backend.api.deps import get_dummy_api_client
+ from oasst_backend.api.deps import api_auth
from oasst_backend.database import engine
- with Session(engine) as session:
- api_client = get_dummy_api_client(session)
- usr = UserStatsRepository(session)
- # usr.update_all_time_frames()
- # session.commit()
- # usr.get_leader_board(UserStatsTimeFrame.total)
- usr.get_user_stats_all_time_frames(UUID("0d6ff62a-0bea-4c56-ade8-b3e0520a10ce"))
+ with Session(engine) as db:
+ api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
+ usr = UserStatsRepository(db)
+ usr.update_all_time_frames()
+ db.commit()
diff --git a/backend/oasst_backend/utils/database_utils.py b/backend/oasst_backend/utils/database_utils.py
index d378c6a6..803b81e8 100644
--- a/backend/oasst_backend/utils/database_utils.py
+++ b/backend/oasst_backend/utils/database_utils.py
@@ -19,6 +19,7 @@ class CommitMode(IntEnum):
NONE = 0
FLUSH = 1
COMMIT = 2
+ ROLLBACK = 3
"""
@@ -41,6 +42,8 @@ def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=s
self.db.commit()
elif auto_commit == CommitMode.FLUSH:
self.db.flush()
+ elif auto_commit == CommitMode.ROLLBACK:
+ self.db.rollback()
if isinstance(result, SQLModel):
self.db.refresh(result)
return result
@@ -75,6 +78,8 @@ def async_managed_tx_method(
self.db.commit()
elif auto_commit == CommitMode.FLUSH:
self.db.flush()
+ elif auto_commit == CommitMode.ROLLBACK:
+ self.db.rollback()
if isinstance(result, SQLModel):
self.db.refresh(result)
return result
@@ -118,8 +123,8 @@ def managed_tx_function(
session.commit()
elif auto_commit == CommitMode.FLUSH:
session.flush()
- if isinstance(result, SQLModel):
- session.refresh(result)
+ elif auto_commit == CommitMode.ROLLBACK:
+ session.rollback()
return result
except OperationalError:
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
diff --git a/backend/oasst_backend/utils/ranking.py b/backend/oasst_backend/utils/ranking.py
new file mode 100644
index 00000000..5538d7a3
--- /dev/null
+++ b/backend/oasst_backend/utils/ranking.py
@@ -0,0 +1,140 @@
+from typing import List
+
+import numpy as np
+
+
+def head_to_head_votes(ranks: List[List[int]]):
+    """Tally pairwise preferences over all ballots.
+
+    Returns `(tallies, names)`: `names` is the sorted first ballot and `tallies[i, j]` is the
+    number of ballots that placed `names[i]` ahead of `names[j]`.
+    This assumes every ballot is a permutation of the same candidates.
+    """
+    num_candidates = len(ranks[0])
+    tallies = np.zeros((num_candidates, num_candidates))
+    names = sorted(ranks[0])
+    # argsort per ballot: column k holds the position of the k-th smallest candidate
+    positions = np.argsort(np.array(ranks), axis=1)
+    for i in range(positions.shape[1]):
+        pos_i = positions[:, i]
+        for j in range(i + 1, positions.shape[1]):
+            pos_j = positions[:, j]
+            # a smaller position means the candidate was preferred on that ballot
+            tallies[i, j] = np.sum(pos_i < pos_j)
+            tallies[j, i] = np.sum(pos_j < pos_i)
+    return tallies, names
+
+
+def cycle_detect(pairs):
+    """Detect cycles by repeatedly removing Condorcet losers until at most one pair is left
+    or no Condorcet loser exists any more.
+
+    A Condorcet loser here is a candidate that appears as the loser of some pair but never
+    as a winner. `pairs` is an array of (winner, loser) rows.
+
+    Returns
+    -------
+    out : False if the pairs do not contain a cycle, True if the pairs contain a cycle
+    """
+    while True:
+        if len(pairs) <= 1:
+            return False
+        winners = pairs[:, 0]
+        losers = [candidate for candidate in np.unique(pairs[:, 1]) if candidate not in winners]
+        if len(losers) == 0:
+            # every remaining candidate is both a winner and a loser,
+            # so the remaining pairs must contain a cycle
+            return True
+        # drop every pair lost by a Condorcet loser and look again
+        pairs = np.array([pair for pair in pairs if pair[1] not in losers])
+
+
+def get_winner(pairs):
+    """
+    Return _one_ candidate that wins at least one pair and loses none.
+    Several such candidates may exist; the smallest one is returned because
+    a ranking needs a single choice. Returns None if there is no such candidate.
+    """
+    losers = np.unique(pairs[:, 1]).astype(int)
+    winners = np.unique(pairs[:, 0]).astype(int)
+    return next((candidate for candidate in winners if candidate not in losers), None)
+
+
+def get_ranking(pairs):
+    """
+    Turn acyclic locked-in (winner, loser) pairs into a ranking, best candidate first.
+
+    The ranking is not necessarily unique: several candidates can be equally ranked
+    winners and one of them has to be picked.
+
+    Candidates that only ever lost to already extracted winners disappear from the
+    remaining pairs; they are appended to the back of the ranking instead of being
+    dropped. Empty input yields an empty ranking instead of raising an IndexError.
+    """
+    pairs = np.asarray(pairs)
+    if pairs.size == 0:
+        return []
+    candidates = np.unique(pairs)
+    ranking = []
+    remaining = pairs
+    while len(remaining) > 1:
+        w = get_winner(remaining)
+        if w is None:
+            # only possible if the pairs contain a cycle; leftovers are appended below
+            break
+        ranking.append(w)
+        # now remove the winner from the list of pairs
+        remaining = np.array([(a, b) for a, b in remaining if a != w])
+    if len(remaining) == 1:
+        ranking.extend(remaining[0])
+    # attach every candidate that vanished from the pairs without being ranked
+    unranked = [c for c in candidates if c not in ranking]
+    return ranking + unranked
+
+
+def ranked_pairs(ranks: List[List[int]]):
+    """
+    Expects a list of rankings for an item like:
+        [("w","x","z","y") for _ in range(3)]
+        + [("w","y","x","z") for _ in range(2)]
+        + [("x","y","z","w") for _ in range(4)]
+        + [("x","z","w","y") for _ in range(5)]
+        + [("y","w","x","z") for _ in range(1)]
+    and returns the consensus ranking (best first) containing every candidate exactly once.
+    The idea is the following:
+    1. create a head-to-head matrix that tallies up all win-lose combinations of preferences
+    2. take all combinations that win more than they lose and sort those by how often they win
+    3. use that to create an (implicit) directed graph
+    4. recursively extract nodes from the graph that do not have incoming edges
+    5. said list is the ranking; candidates without any locked-in win are attached to the back
+    An empty list of rankings yields an empty result.
+    """
+    if len(ranks) == 0:
+        return []
+    tallies, names = head_to_head_votes(ranks)
+    # note: the resulting margin matrix is skew-symmetric
+    tallies = tallies - tallies.T
+    num_candidates = len(names)
+    # order by strength of victory (Tideman's original method); tied majorities keep
+    # their discovery order because the sort is stable
+    majorities = [
+        (i, j, tallies[i, j])
+        for i in range(num_candidates)
+        for j in range(num_candidates)
+        if tallies[i, j] > 0
+    ]
+    majorities.sort(key=lambda m: m[2], reverse=True)
+    # now do lock ins
+    lock_ins = []
+    for winner, loser, _ in majorities:
+        # invariant: lock_ins has no cycles here
+        lock_ins.append((winner, loser))
+        if cycle_detect(np.array(lock_ins)):
+            # if there's a cycle, delete the new addition and continue
+            lock_ins.pop()
+    # Without any majority (all candidates tied) there is nothing to extract.
+    ordered = [int(c) for c in get_ranking(np.array(lock_ins))] if lock_ins else []
+    # Attach every candidate that is not part of the extracted order to the back.
+    # There may be several such overall losers or tied candidates; they are returned
+    # in name order as no closer ranking can be found.
+    placed = set(ordered)
+    ordered.extend(i for i in range(num_candidates) if i not in placed)
+    return [names[i] for i in ordered]
+
+
+if __name__ == "__main__":
+    # small manual demo: compute and print the consensus ranking of a few sample ballots
+    ranks = (
+        [("w", "x", "z", "y") for _ in range(1)]
+        + [("w", "y", "x", "z") for _ in range(2)]
+        # + [("x","y","z","w") for _ in range(4)]
+        + [("x", "z", "w", "y") for _ in range(5)]
+        + [("y", "w", "x", "z") for _ in range(1)]
+        # [("y","z","w","x") for _ in range(1000)]
+    )
+    rp = ranked_pairs(ranks)
+    print(rp)
diff --git a/docker-compose.yaml b/docker-compose.yaml
index 78192eb3..908457cd 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -18,7 +18,8 @@ services:
# This DB is for the FastAPI Backend.
db:
- image: postgres
+ image: ghcr.io/laion-ai/open-assistant/oasst-postgres
+ pull_policy: always
restart: always
ports:
- 5432:5432
diff --git a/docker/oasst-postgres/Dockerfile b/docker/oasst-postgres/Dockerfile
new file mode 100644
index 00000000..5c4aad80
--- /dev/null
+++ b/docker/oasst-postgres/Dockerfile
@@ -0,0 +1,11 @@
+FROM postgres:15
+
+# unzip and curl are needed to fetch and unpack the AWS CLI installer
+RUN apt-get update && apt-get install -y unzip curl && rm -rf /var/lib/apt/lists/*
+
+# Download, install and clean up the AWS CLI v2 in one layer so that neither the
+# installer archive nor the unpacked ./aws directory stays in the image.
+# -f makes curl fail on HTTP errors instead of saving an error page as the zip.
+RUN curl -fsSL "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" \
+    && unzip -q awscliv2.zip \
+    && ./aws/install \
+    && rm -rf awscliv2.zip ./aws
+
+# backup helper, see backup_pg_to_s3.sh
+COPY ./backup_pg_to_s3.sh .
diff --git a/docker/oasst-postgres/backup_pg_to_s3.sh b/docker/oasst-postgres/backup_pg_to_s3.sh
new file mode 100755
index 00000000..ff509947
--- /dev/null
+++ b/docker/oasst-postgres/backup_pg_to_s3.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+# Dump the `postgres` database with pg_dump and upload the dump to
+# s3://$S3_BUCKET_NAME/ under a timestamped file name.
+
+set -euo pipefail
+set -x
+
+# fail early with a clear message instead of uploading to "s3:///..."
+: "${S3_BUCKET_NAME:?S3_BUCKET_NAME must be set}"
+
+# filename with timestamp
+filename="postgres-$(date +%Y-%m-%d_%H-%M-%S).sql"
+dumpfile="/tmp/${filename}"
+
+# remove the local dump on exit, also when pg_dump or the upload fails
+trap 'rm -f "$dumpfile"' EXIT
+
+# perform pg_dump
+pg_dump -U postgres postgres > "$dumpfile"
+
+# upload to s3
+aws s3 cp "$dumpfile" "s3://${S3_BUCKET_NAME}/${filename}"
diff --git a/docs/docs/faq/README.md b/docs/docs/faq/README.md
new file mode 100644
index 00000000..6b510c97
--- /dev/null
+++ b/docs/docs/faq/README.md
@@ -0,0 +1,3 @@
+# Frequently Asked Questions
+
+This page collects some of the most frequently asked questions.
diff --git a/docs/docs/faq/faq.md b/docs/docs/faq/faq.md
new file mode 100644
index 00000000..0db57d30
--- /dev/null
+++ b/docs/docs/faq/faq.md
@@ -0,0 +1,65 @@
+### Docker-Compose instead of Docker Compose
+
+If you are using `docker-compose` instead of `docker compose` (note the " "
+instead of the "-"), you should update your Docker CLI to the latest version.
+`docker compose` is the current version of the tool and should be used instead
+of `docker-compose`.
+
+For more details and information check out
+[this SO thread](https://stackoverflow.com/questions/66514436/difference-between-docker-compose-and-docker-compose)
+that explains it all in detail.
+
+### Pre-commit
+
+We use pre-commit to keep the code quality high and to enforce a consistent
+code style.
+
+The steps that you need to follow to be able to use it are:
+
+```bash
+ # install pre-commit in your python environment
+ pip3 install pre-commit
+
+ # install the pre-commit git hooks in your local clone of the repository
+ pre-commit install
+```
+
+From now on, `pre-commit` will run on the staged files every time you commit.
+If it reports any errors, you will need to fix them and then stage and commit
+the changes again.
+
+### Docker Cannot Start Container: Permission Denied
+
+Instead of always running docker as root, you can create a `docker` group
+whose members are allowed to use docker (this grants root-level privileges):
+
+```bash
+ # Create the docker group
+ sudo groupadd docker
+
+ # Add the current user to the group
+ sudo usermod -aG docker $USER
+
+ # Log in to the group (apply the group changes to the current terminal session)
+ newgrp docker
+```
+
+After that, you should be able to run docker without `sudo`, e.g.
+`docker run hello-world`. If that still fails, try rebooting your machine:
+
+```bash
+ reboot
+```
+
+### Docker Cannot Stop Container
+
+If you try to shut down the services (`docker compose down`) and you get a
+permission denied error (using the root user), you can try the following:
+
+```bash
+ # Restart docker daemon
+ sudo systemctl restart docker.socket docker.service
+
+ # And remove the container (replace <container> with its name or ID)
+ docker rm -f <container>
+```
diff --git a/docs/sidebars.js b/docs/sidebars.js
index 2f7baedf..83063239 100644
--- a/docs/sidebars.js
+++ b/docs/sidebars.js
@@ -70,6 +70,15 @@ const sidebars = {
},
items: ["presentations/list"],
},
+ {
+ type: "category",
+ label: "FAQ",
+ link: {
+ type: "doc",
+ id: "faq/README",
+ },
+ items: ["faq/faq"],
+ },
],
};
diff --git a/model/supervised_finetuning/configs/config.yaml b/model/supervised_finetuning/configs/config.yaml
index 2eaa6686..815c2e75 100644
--- a/model/supervised_finetuning/configs/config.yaml
+++ b/model/supervised_finetuning/configs/config.yaml
@@ -30,6 +30,7 @@ defaults:
- joke
- gsm8k
- samsum
+ - soda_dialogue
cache_dir: .cache
loss_fn: CrossEntropyLoss
eval_size:
diff --git a/model/supervised_finetuning/custom_datasets/__init__.py b/model/supervised_finetuning/custom_datasets/__init__.py
index e293af3d..3bec37e7 100644
--- a/model/supervised_finetuning/custom_datasets/__init__.py
+++ b/model/supervised_finetuning/custom_datasets/__init__.py
@@ -1,5 +1,5 @@
from custom_datasets.prompt_dialogue import PromptGeneratedDataset
-from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, WebGPT
+from custom_datasets.qa_datasets import SODA, JokeExplaination, QADataset, SODADialogue, WebGPT
from custom_datasets.summarization import SummarizationDataset
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
@@ -36,6 +36,9 @@ def get_one_dataset(conf, dataset_name):
elif dataset_name == "soda":
dataset = SODA(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.1)
+ elif dataset_name == "soda_dialogue":
+ dataset = SODADialogue(conf.cache_dir)
+ train, eval = train_val_dataset(dataset, val_split=0.1)
elif dataset_name == "joke":
dataset = JokeExplaination(conf.cache_dir)
train, eval = train_val_dataset(dataset, val_split=0.2)
diff --git a/model/supervised_finetuning/custom_datasets/qa_datasets.py b/model/supervised_finetuning/custom_datasets/qa_datasets.py
index eed9c644..d191c56c 100644
--- a/model/supervised_finetuning/custom_datasets/qa_datasets.py
+++ b/model/supervised_finetuning/custom_datasets/qa_datasets.py
@@ -106,7 +106,12 @@ class SODA(Dataset):
def process_soda_convo(self, data):
pairs = []
play_as = data["speakers"][1]
- prefix = "{}. {}".format(data["narrative"], "your name {}".format(play_as))
+ prefix = "{}{}. {}{}".format(
+ QA_SPECIAL_TOKENS["StartPrefix"],
+ data["narrative"],
+ "your name {}".format(play_as),
+ QA_SPECIAL_TOKENS["EndPrefix"],
+ )
question, answer = "", ""
prefix, postfix = "", ""
previous_chat = []
@@ -119,7 +124,9 @@ class SODA(Dataset):
answer = convo
postfix = data["speakers"][idx]
if len(question) and len(answer) and prefix != postfix and postfix == play_as:
- history = "".join(["{}{}".format(*p) for p in previous_chat])
+ history = "".join(
+ ["{}{}{}".format(p[0], QA_SPECIAL_TOKENS["Answer"], p[1]) for p in previous_chat]
+ )
if len(history):
history += ""
pairs.append((prefix + history + question, answer))
@@ -148,6 +155,75 @@ class SODA(Dataset):
 return question, answer
+class SODADialogue(Dataset):
+    """Multi-turn dialogues parsed from the SODA dialogue dump.
+
+    Every item is a flat tuple ``(question_1, answer_1, question_2, answer_2, ...)``
+    holding the user/assistant turns of one conversation in order.
+    """
+
+    url = "https://drive.google.com/uc?id=1TOGQfr419n8wpzJpYLLw4nB3tSKD8zXV"
+
+    def __init__(self, cache_dir, verbose=True):
+        """Download (on first use) and parse the dataset.
+
+        Args:
+            cache_dir: directory holding ``soda_dialog.jsonl``; created if missing.
+            verbose: when true, print how many malformed turns were skipped.
+        """
+        path = os.path.join(cache_dir, "soda_dialog.jsonl")
+
+        if not os.path.exists(path):
+            import gzip
+            import shutil
+
+            import gdown
+
+            # the download target directory may not exist yet
+            os.makedirs(cache_dir, exist_ok=True)
+            archive_path = os.path.join(cache_dir, "soda_dialog.jsonl.gz")
+            gdown.download(self.url, output=archive_path)
+
+            with gzip.open(archive_path, "rb") as f_in:
+                with open(path, "wb") as f_out:
+                    shutil.copyfileobj(f_in, f_out)
+
+        self.pairs = []
+        faulty = 0
+        with open(path, encoding="utf-8") as fin:
+            for line in fin:
+                conversation = json.loads(line)
+                question_answer_pairs = ()
+
+                question_answers = conversation["text"].split("User: ")
+                for question_answer in question_answers[1:]:  # first element is empty
+                    try:
+                        question, answer = question_answer.split("\nAssistant: ")
+                        question_answer_pairs += (
+                            question,
+                            answer,
+                        )
+                    except ValueError:
+                        # there might be some extra 'User: ' or 'Assistant: ' tokens in the dataset that cause trouble..
+                        faulty += 1
+                        continue
+
+                # a conversation whose turns were all malformed would otherwise be stored as an empty sample
+                if question_answer_pairs:
+                    self.pairs.append(question_answer_pairs)
+
+        if verbose:
+            print("For SODA dialogue dataset found {} faults within the total {} dialogs".format(faulty, len(self)))
+
+    def __len__(self):
+        """Return the number of parsed dialogues."""
+        return len(self.pairs)
+
+    def __getitem__(self, index):
+        """Return the flat question/answer tuple of the dialogue at ``index``."""
+        return self.pairs[index]
+
+
class JokeExplaination(Dataset):
""" """
diff --git a/model/supervised_finetuning/requirements.txt b/model/supervised_finetuning/requirements.txt
index 0e6eeb51..8f8cc63c 100644
--- a/model/supervised_finetuning/requirements.txt
+++ b/model/supervised_finetuning/requirements.txt
@@ -3,10 +3,12 @@ bitsandbytes==0.36.0.post2
datasets==2.8.0
deepspeed==0.7.7
evaluate==0.4.0
+gdown
mpi4py==3.1.4
nltk==3.8.1
-numpy==1.23.0
-PyYAML==6.0
+numpy>=1.22.4
+py7zr
+PyYAML>=6.0
scikit_learn==1.2.0
-torch==1.13.1
+torch>=1.11.0
transformers==4.25.1
diff --git a/notebooks/data-augmentation/codet-data/Augment_CodeT_codegen.ipynb b/notebooks/data-augmentation/codet-data/Augment_CodeT_codegen.ipynb
new file mode 100644
index 00000000..7fc89a06
--- /dev/null
+++ b/notebooks/data-augmentation/codet-data/Augment_CodeT_codegen.ipynb
@@ -0,0 +1,198 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## CodeT Code Generation Datasets\n",
+ "\n",
+ "[](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/data-augmentation/codet-data/Augment_CodeT_codegen.ipynb)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This notebook contains code to parse CodeT code generation prompt and solution data and modify to `(prompt, solution)` pairs outputted in a `.jsonl` file.\n",
+ "\n",
+ "Requirements: `requests`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "from pathlib import Path\n",
+ "import requests\n",
+ "from typing import Dict, List, Tuple"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "DATA_FILES: List[str] = [\n",
+ " \"HumanEval_for_code_generation.jsonl\",\n",
+ " \"mbpp_sanitized_for_code_generation.jsonl\",\n",
+ "]\n",
+ "\n",
+ "OUT_FILES: List[str] = [\n",
+ " \"HumanEval_codegen.jsonl\",\n",
+ " \"mbpp_codegen.jsonl\",\n",
+ "]\n",
+ "\n",
+ "FILE_PATHS: List[Path] = [Path(f\"data/{data_file}\") for data_file in DATA_FILES]\n",
+ "\n",
+ "OUT_PATHS: List[Path] = [Path(f\"data/augmented/{out_file}\") for out_file in OUT_FILES]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def download_file(filename: str):  # fetch one CodeT dataset file into data/\n",
+ "    response = requests.get(f\"https://raw.githubusercontent.com/microsoft/CodeT/main/CodeT/data/dataset/{filename}\")\n",
+ "    response.raise_for_status()  # do not save an HTTP error page as the dataset\n",
+ "    Path(\"data\").mkdir(exist_ok=True)  # the write below fails if data/ is missing\n",
+ "    Path(f\"data/{filename}\").write_bytes(response.content)\n",
+ "\n",
+ "\n",
+ "for filename in DATA_FILES:\n",
+ " download_file(filename)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can find the docstring, use its contents as the instruction (prefixed with \"Write a Python function corresponding to the docstring:\") and then use the content prior to the docstring and the canonical solution as the response."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_docstring_indices(prompt_lines: List[str]) -> Tuple[int, int]:\n",
+ "    \"\"\"Return the indices of the first two lines that start with a triple quote.\n",
+ "\n",
+ "    Raises ValueError when fewer than two such lines exist.\n",
+ "    \"\"\"\n",
+ "    docstring_start = None\n",
+ "    for i, line in enumerate(prompt_lines):\n",
+ "        if not line.strip().startswith(('\"\"\"', \"'''\")):\n",
+ "            continue\n",
+ "        # compare against None: a docstring opening on line 0 is a valid start\n",
+ "        if docstring_start is not None:\n",
+ "            return docstring_start, i\n",
+ "        docstring_start = i\n",
+ "    raise ValueError(f\"No complete docstring found!\\n{prompt_lines}\")\n",
+ "\n",
+ "\n",
+ "def get_before(prompt_lines: List[str], before: int) -> List[str]:\n",
+ " before_lines = prompt_lines[:before]\n",
+ " return before_lines\n",
+ "\n",
+ "\n",
+ "def get_between(prompt_lines: List[str], start: int, end: int) -> List[str]:\n",
+ " between_lines = prompt_lines[start:end]\n",
+ " return between_lines"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_request_and_solution(sample: dict) -> Tuple[str, str]:\n",
+ "    \"\"\"Build an (instruction, solution) pair from one CodeT code generation sample.\"\"\"\n",
+ "    prompt_lines = sample[\"prompt\"].splitlines()\n",
+ "\n",
+ "    docstring_start, docstring_end = get_docstring_indices(prompt_lines)\n",
+ "\n",
+ "    # Extract prompt: the docstring lines become the instruction text\n",
+ "    in_docstring = get_between(prompt_lines, docstring_start, docstring_end)\n",
+ "    # Strip the opening delimiter in either quote style, plus any \"...\" on that line\n",
+ "    in_docstring[0] = in_docstring[0].replace('\"\"\"', \"\").replace(\"'''\", \"\").replace(\"...\", \"\").strip()\n",
+ "    request = \"Write a Python function corresponding to the docstring: \" + \" \".join([p.strip() for p in in_docstring])\n",
+ "\n",
+ "    # Extract solution: the code before the docstring followed by the canonical solution\n",
+ "    before_docstring = get_before(prompt_lines, docstring_start)\n",
+ "    after_docstring = sample[\"canonical_solution\"].splitlines()\n",
+ "    solution = before_docstring + after_docstring\n",
+ "    # Gets rid of consecutive empty lines; index 0 has no predecessor, so it is always kept\n",
+ "    solution = [v for i, v in enumerate(solution) if v != \"\" or i == 0 or solution[i - 1] != \"\"]\n",
+ "    solution = \"\\n\".join(solution)\n",
+ "\n",
+ "    return request, solution"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def process_file(file_path: Path, out_path: Path):\n",
+ "    \"\"\"Convert every sample of one downloaded .jsonl file and write the pairs to out_path.\"\"\"\n",
+ "    samples = list(map(json.loads, file_path.read_text().splitlines()))\n",
+ "\n",
+ "    output = []\n",
+ "    for sample in samples:\n",
+ "        prompt, solution = get_request_and_solution(sample)\n",
+ "        output.append({\"prompt\": prompt, \"solution\": solution})\n",
+ "\n",
+ "    out_path.parent.mkdir(parents=True, exist_ok=True)  # open() fails if the output directory is missing\n",
+ "    with open(out_path, \"w\") as f:\n",
+ "        for sample in output:\n",
+ "            f.write(json.dumps(sample) + \"\\n\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for file_path, out_path in zip(FILE_PATHS, OUT_PATHS):\n",
+ " process_file(file_path, out_path)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.10.5 ('venv': venv)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.5"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "1f9a0efd3e4a33b8f30a65df6ca5a95cc3f93ce2f11519ee8c13fe711de61465"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/data-augmentation/codet-data/Augment_CodeT_testgen.ipynb b/notebooks/data-augmentation/codet-data/Augment_CodeT_testgen.ipynb
new file mode 100644
index 00000000..c4641327
--- /dev/null
+++ b/notebooks/data-augmentation/codet-data/Augment_CodeT_testgen.ipynb
@@ -0,0 +1,191 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## CodeT Test Generation Datasets\n",
+ "\n",
+ "[](https://colab.research.google.com/github/LAION-AI/Open-Assistant/blob/main/notebooks/data-augmentation/codet-data/Augment_CodeT_testgen.ipynb)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This notebook contains code to parse CodeT test case generation prompt and solution data and modify to `(prompt, solution)` pairs outputted in a `.jsonl` file.\n",
+ "\n",
+ "Requirements: `requests`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "from pathlib import Path\n",
+ "import requests\n",
+ "from typing import List, Tuple"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "DATA_FILES: List[str] = [\n",
+ " \"HumanEval_for_test_case_generation.jsonl\",\n",
+ " \"mbpp_sanitized_for_test_case_generation.jsonl\",\n",
+ "]\n",
+ "\n",
+ "OUT_FILES: List[str] = [\n",
+ " \"HumanEval_testgen.jsonl\",\n",
+ " \"mbpp_testgen.jsonl\",\n",
+ "]\n",
+ "\n",
+ "FILE_PATHS: List[Path] = [Path(f\"data/{data_file}\") for data_file in DATA_FILES]\n",
+ "\n",
+ "OUT_PATHS: List[Path] = [Path(f\"data/augmented/{out_file}\") for out_file in OUT_FILES]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def download_file(filename: str):  # fetch one CodeT dataset file into data/\n",
+ "    response = requests.get(f\"https://raw.githubusercontent.com/microsoft/CodeT/main/CodeT/data/dataset/{filename}\")\n",
+ "    response.raise_for_status()  # do not save an HTTP error page as the dataset\n",
+ "    Path(\"data\").mkdir(exist_ok=True)  # the write below fails if data/ is missing\n",
+ "    Path(f\"data/{filename}\").write_bytes(response.content)\n",
+ "\n",
+ "\n",
+ "for filename in DATA_FILES:\n",
+ " download_file(filename)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_docstring_indices(prompt_lines: List[str]) -> Tuple[int, int]:\n",
+ "    \"\"\"Return the indices of the first two lines that start with a triple quote.\n",
+ "\n",
+ "    Raises ValueError when fewer than two such lines exist.\n",
+ "    \"\"\"\n",
+ "    docstring_start = None\n",
+ "    for i, line in enumerate(prompt_lines):\n",
+ "        if not line.strip().startswith(('\"\"\"', \"'''\")):\n",
+ "            continue\n",
+ "        # compare against None: a docstring opening on line 0 is a valid start\n",
+ "        if docstring_start is not None:\n",
+ "            return docstring_start, i\n",
+ "        docstring_start = i\n",
+ "    raise ValueError(f\"No complete docstring found!\\n{prompt_lines}\")\n",
+ "\n",
+ "\n",
+ "def get_between(prompt_lines: List[str], start: int, end: int) -> List[str]:\n",
+ " between_lines = prompt_lines[start:end]\n",
+ " return between_lines"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_request(sample: dict) -> str:\n",
+ "    \"\"\"Build the test-writing instruction from the docstring of one CodeT sample.\"\"\"\n",
+ "    prompt_lines = sample[\"prompt\"].splitlines()\n",
+ "\n",
+ "    docstring_start, docstring_end = get_docstring_indices(prompt_lines)\n",
+ "\n",
+ "    # Extract prompt: the docstring lines become the instruction text\n",
+ "    in_docstring = get_between(prompt_lines, docstring_start, docstring_end)\n",
+ "    # Strip the opening delimiter in either quote style, plus any \"...\" on that line\n",
+ "    in_docstring[0] = in_docstring[0].replace('\"\"\"', \"\").replace(\"'''\", \"\").replace(\"...\", \"\").strip()\n",
+ "    request = \"Write a test for a Python function with the following docstring: \" + \" \".join(\n",
+ "        [p.strip() for p in in_docstring]\n",
+ "    )\n",
+ "\n",
+ "    return request\n",
+ "\n",
+ "\n",
+ "def get_test_code(sample: dict) -> str:\n",
+ "    \"\"\"Return the test source from the last line containing `def check(` onwards (all of it if absent).\"\"\"\n",
+ "    test_lines = sample[\"test\"].splitlines()\n",
+ "    start = 0\n",
+ "    for i, line in enumerate(test_lines):\n",
+ "        if \"def check(\" in line:\n",
+ "            start = i\n",
+ "    return \"\\n\".join(test_lines[start:])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def process_file(file_path: Path, out_path: Path):\n",
+ "    \"\"\"Convert every sample of one downloaded .jsonl file and write the pairs to out_path.\"\"\"\n",
+ "    samples = list(map(json.loads, file_path.read_text().splitlines()))\n",
+ "\n",
+ "    output = []\n",
+ "    for sample in samples:\n",
+ "        prompt = get_request(sample)\n",
+ "        test = get_test_code(sample)\n",
+ "        output.append({\"prompt\": prompt, \"solution\": test})\n",
+ "\n",
+ "    out_path.parent.mkdir(parents=True, exist_ok=True)  # open() fails if the output directory is missing\n",
+ "    with open(out_path, \"w\") as f:\n",
+ "        for sample in output:\n",
+ "            f.write(json.dumps(sample) + \"\\n\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for file_path, out_path in zip(FILE_PATHS, OUT_PATHS):\n",
+ " process_file(file_path, out_path)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.10.5 ('venv': venv)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.5"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "1f9a0efd3e4a33b8f30a65df6ca5a95cc3f93ce2f11519ee8c13fe711de61465"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/data-augmentation/codet-data/README.md b/notebooks/data-augmentation/codet-data/README.md
new file mode 100644
index 00000000..985883b4
--- /dev/null
+++ b/notebooks/data-augmentation/codet-data/README.md
@@ -0,0 +1,15 @@
+# CodeT Datasets
+
+This folder contains two notebooks.
+
+One will download the data used for Microsoft CodeT for tuning a model for
+Python code generation from function docstrings, augment the data into prompt
+and solution pairs and write them to `.jsonl` files.
+
+The other will download the data used for Microsoft CodeT for tuning a model for
+Python test generation from corresponding function docstrings, augment the data
+into prompt and solution pairs and write them to `.jsonl` files.
+
+## Requirements
+
+Both notebooks require the library `requests`.
diff --git a/notebooks/diverse/README.md b/notebooks/diverse/README.md
new file mode 100644
index 00000000..a56806c6
--- /dev/null
+++ b/notebooks/diverse/README.md
@@ -0,0 +1,11 @@
+# DIVERSE Downloader
+
+Diverse is a notebook that downloads the DIVERSE dataset and converts it into
+OpenAssistant Data Scheme formats.
+
+---
+
+## Contributing
+
+Feel free to contribute to this notebook. It's not perfect and additional
+functionality is planned.
diff --git a/notebooks/diverse/diverse.ipynb b/notebooks/diverse/diverse.ipynb
new file mode 100644
index 00000000..c4c6a518
--- /dev/null
+++ b/notebooks/diverse/diverse.ipynb
@@ -0,0 +1,477 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "00b2848c",
+ "metadata": {},
+ "source": [
+ "# Diverse Dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "[](https://colab.research.google.com/drive/1CmXjXVrmPtpAVBaogBSuDclM0O6Zzewf?usp=sharing)"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d81932b9",
+ "metadata": {},
+ "source": [
+ "The purpose of this notebook is to download the DIVERSE dataset and convert it into a format that can be used for training the OpenAssistant.\n",
+ "\n",
+ "The DIVERSE repo can be found here: https://github.com/microsoft/CodeT/tree/main/DIVERSE\n",
+ "\n",
+ "If you extend or use this work, please cite the relevant papers:\n",
+ "```\n",
+ "@article{li2022advance,\n",
+ " title={On the Advance of Making Language Models Better Reasoners},\n",
+ " author={Li, Yifei and Lin, Zeqi and Zhang, Shizhuo and Fu, Qiang and Chen, Bei and Lou, Jian-Guang and Chen, Weizhu},\n",
+ " journal={arXiv preprint arXiv:2206.02336},\n",
+ " year={2022}\n",
+ "}\n",
+ "\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a8c98078",
+ "metadata": {},
+ "source": [
+ "# OpenAssistant Data Scheme"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2731f88f",
+ "metadata": {},
+ "source": [
+ "We will use the data scheme that can be found in the docs for Open-Assistant. This code is taken from the StackExchange notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "d35ab066",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from typing import TypeVar, List, Dict, Any, Literal\n",
+ "from json import JSONEncoder\n",
+ "\n",
+ "T = TypeVar(\"T\", bound=\"ConversationTreeNode\")\n",
+ "\n",
+ "\n",
+ "class ConversationTreeNode:\n",
+ " text: str # The text of the node\n",
+ " role: Literal[\"prompter\", \"assistant\"] # Whether the node is a user prompt/follow-up or an assistant response\n",
+ " children: List[T] # The children of the node (if you have a linear conversation, this will be of length 0 or 1)\n",
+ " metadata: Dict[str, Any] # Node metadata (see below)\n",
+ "\n",
+ " def __init__(\n",
+ " self, text: str, role: Literal[\"prompter\", \"assistant\"], children: List[T], metadata: Dict[str, Any]\n",
+ " ) -> None:\n",
+ " self.text = text\n",
+ " self.role = role\n",
+ " self.children = children\n",
+ " self.metadata = metadata\n",
+ "\n",
+ "\n",
+ "class ConversationTree:\n",
+ " root: ConversationTreeNode # The node containing the initial prompt\n",
+ " metadata: Dict[str, Any] # Tree metadata, different from root node metadata.\n",
+ "\n",
+ " def __init__(self, root: ConversationTreeNode, metadata: Dict[str, Any]) -> None:\n",
+ " self.root = root\n",
+ " self.metadata = metadata\n",
+ "\n",
+ "\n",
+ "# subclass JSONEncoder\n",
+ "class TreeEncoder(JSONEncoder):\n",
+ " def default(self, o):\n",
+ " return o.__dict__"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e7457bae",
+ "metadata": {},
+ "source": [
+ "# Download and convert"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "54b0fd63",
+ "metadata": {},
+ "source": [
+ "First, we import pandas and any other libraries that we'll need."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "9317d4b4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import json"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "62dc4e18",
+ "metadata": {},
+ "source": [
+ "The following is a simple function to take the data (which has two columns) and convert it to a tree with a root node (question) and one child (answer)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "963e0d92",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import re\n",
+ "\n",
+ "\n",
+ "def convert_diverse(dataset_json_path):\n",
+ "    \"\"\"Convert one DIVERSE .jsonl file to conversation trees saved as ./<dataset name>.json.\"\"\"\n",
+ "    # read files using pandas\n",
+ "    ds = pd.read_json(dataset_json_path, lines=True)\n",
+ "\n",
+ "    # create dataset name from the path argument (not from a notebook-level global)\n",
+ "    ds_name = \"diverse\" + str(dataset_json_path).split(\"data\")[-1].replace(\"/\", \"_\").split(\".\")[0]\n",
+ "    print(\"*****\", ds_name, \"****\")\n",
+ "    print(\"Example of raw dataset\")\n",
+ "    print(ds.head(2))\n",
+ "\n",
+ "    # create conversation forest\n",
+ "    # Print the pairs of the first item so the user of this notebook has an idea of what they are looking at\n",
+ "    first_sample = True\n",
+ "    print(\"\\nExamples from converted dataset\")\n",
+ "    conversation_forest = []\n",
+ "    for item in ds[\"context\"]:\n",
+ "        # Find all questions and answers in the context text\n",
+ "        flat_item = item.replace(\"\\n\", \" \")\n",
+ "        answers = re.findall(r\"Answer:?(.*?)#\", flat_item)\n",
+ "        questions = re.findall(r\"Question:?(.*?) Answer:\", flat_item)\n",
+ "\n",
+ "        # The last question does not contain an answer so we drop it every time.\n",
+ "        if len(answers) < len(questions):\n",
+ "            questions.pop(-1)\n",
+ "\n",
+ "        for (answer, question) in zip(answers, questions):\n",
+ "            if first_sample:\n",
+ "                print(f\"Q: {question}\")\n",
+ "                print(f\"A: {answer}\")\n",
+ "            root = ConversationTreeNode(text=question, role=\"prompter\", children=[], metadata=None)\n",
+ "            child = ConversationTreeNode(text=answer, role=\"assistant\", children=[], metadata=None)\n",
+ "            root.children.append(child)\n",
+ "            conversation_tree = ConversationTree(root=root, metadata={\"dataset\": ds_name})\n",
+ "            conversation_forest.append(conversation_tree)\n",
+ "\n",
+ "        first_sample = False\n",
+ "\n",
+ "    conversation_forest_json = [json.loads(TreeEncoder().encode(tree)) for tree in conversation_forest]\n",
+ "\n",
+ "    # write through a context manager so the output file is flushed and closed\n",
+ "    with open(f\"./{ds_name}.json\", \"w+\") as out_file:\n",
+ "        print(json.dumps(conversation_forest_json, indent=4), file=out_file)\n",
+ "    print(\"\\n\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e4448c9a",
+ "metadata": {},
+ "source": [
+ "We now clone the repository containing the dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "06e7719e",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Cloning into 'CodeT'...\r\n",
+ "remote: Enumerating objects: 144, done.\u001B[K\r\n",
+ "remote: Counting objects: 100% (16/16), done.\u001B[K\r\n",
+ "remote: Compressing objects: 100% (16/16), done.\u001B[K\r\n",
+ "remote: Total 144 (delta 1), reused 0 (delta 0), pack-reused 128\u001B[Kving objects: 8% (12/144), 3.70 MiB | 1.35 MiB/s Receiving objects: 12% (18/144), 5.13 MiB | 1.58 MiB/s Receiving objects: 13% (19/144), 11.36 MiB | 2.39 MiB/s Receiving objects: 13% (20/144), 19.19 MiB | 3.97 MiB/s Receiving objects: 15% (22/144), 19.19 MiB | 3.97 MiB/s Receiving objects: 23% (34/144), 22.15 MiB | 4.37 MiB/s Receiving objects: 27% (39/144), 22.15 MiB | 4.37 MiB/s Receiving objects: 29% (42/144), 25.30 MiB | 4.77 MiB/s Receiving objects: 32% (47/144), 28.71 MiB | 5.22 MiB/s Receiving objects: 41% (60/144), 32.41 MiB | 5.62 MiB/s Receiving objects: 54% (78/144), 32.41 MiB | 5.62 MiB/s Receiving objects: 60% (87/144), 39.34 MiB | 6.20 MiB/s Receiving objects: 61% (88/144), 47.14 MiB | 6.79 MiB/s Receiving objects: 64% (93/144), 51.36 MiB | 7.12 MiB/s Receiving objects: 66% (96/144), 55.58 MiB | 7.40 MiB/s \r\n",
+ "Receiving objects: 100% (144/144), 56.76 MiB | 4.97 MiB/s, done.\r\n",
+ "Resolving deltas: 100% (33/33), done.\r\n",
+ "Checking out files: 100% (64/64), done.\r\n"
+ ]
+ }
+ ],
+ "source": [
+ "!git clone https://github.com/microsoft/CodeT.git"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "89a166c9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "diverse_files = [\n",
+ " \"CodeT/DIVERSE/data/sqa/split1/test.jsonl\",\n",
+ " \"CodeT/DIVERSE/data/sqa/split1/train.jsonl\",\n",
+ " \"CodeT/DIVERSE/data/sqa/split2/test.jsonl\",\n",
+ " \"CodeT/DIVERSE/data/sqa/split2/train.jsonl\",\n",
+ " \"CodeT/DIVERSE/data/gsm8k/test.jsonl\",\n",
+ " \"CodeT/DIVERSE/data/gsm8k/train.jsonl\",\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "2da75b14",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "***** diverse_sqa_split1_test ****\n",
+ "Example of raw dataset\n",
+ " context \\\n",
+ "0 Question:\\nIs clerk of Supreme Court of Canada... \n",
+ "1 Question:\\nIs Saturn named after king of gods ... \n",
+ "\n",
+ " samples \\\n",
+ "0 [Snoopy is a dog.\\nChance is a dog.\\nDogs look... \n",
+ "1 [Snoopy is a cartoon dog.\\nChance is a cartoon... \n",
+ "\n",
+ " metadata \n",
+ "0 {'type': 'solution', 'question': 'Does Snoopy ... \n",
+ "1 {'type': 'solution', 'question': 'Does Snoopy ... \n",
+ "\n",
+ "Examples from converted dataset\n",
+ "Q: Is clerk of Supreme Court of Canada safe profession for someone with seismophobia?\n",
+ "A: The Supreme Court of Canada is in Ottawa, Canada. Ottawa is in Ontario, Canada. Ontario is in the Canadian Shield. The Canadian Shield is a stable tectonic plate. Thus, Ottawa is not prone to earthquakes. Thus, the clerk of the Supreme Court of Canada is a safe profession for someone with seismophobia. So the answer is yes.\n",
+ "Q: During the Cuban revolution, did the US experience a population boom?\n",
+ "A: The Cuban revolution was in 1959. The US population in 1959 was about 180 million. The US population in 2010 was about 310 million. Thus, the US population increased by about 130 million. So the answer is yes.\n",
+ "Q: Can the largest crustacean stretch out completely on a king-sized mattress?\n",
+ "A: The largest crustacean is the Japanese spider crab. The Japanese spider crab has a leg span of 3.8 meters. A king-sized mattress is 1.9 meters wide. Thus, the Japanese spider crab could not stretch out completely on a king-sized mattress. So the answer is no.\n",
+ "Q: Could morphine cure HIV?\n",
+ "A: Morphine is a painkiller. HIV is a virus. Painkillers do not cure viruses. Thus, morphine could not cure HIV. So the answer is no.\n",
+ "Q: Is Christopher Nolan indebted to Bob Kane?\n",
+ "A: Bob Kane created Batman. Christopher Nolan directed the Dark Knight trilogy. The Dark Knight trilogy is about Batman. Thus, Christopher Nolan is indebted to Bob Kane. So the answer is yes.\n",
+ "Q: In baseball, is a \"Homer\" named after the poet Homer who wrote the Odyssey?\n",
+ "A: A \"Homer\" is a home run. The term \"home run\" was coined by Harry Wright in 1858. Homer lived from about 800 BC to about 700 BC. Thus, the term \"home run\" was not named after the poet Homer. So the answer is no.\n",
+ "\n",
+ "\n",
+ "***** diverse_sqa_split1_train ****\n",
+ "Example of raw dataset\n",
+ " context \\\n",
+ "0 Question:\\nIs eating a Dicopomorpha echmeptery... \n",
+ "1 Question:\\nWas Alexander the Great baptized?\\n... \n",
+ "\n",
+ " samples \\\n",
+ "0 [New York Public Library is a library.\\nSix Fl... \n",
+ "1 [The New York Public Library is in New York Ci... \n",
+ "\n",
+ " metadata \n",
+ "0 {'type': 'solution', 'question': 'Could you go... \n",
+ "1 {'type': 'solution', 'question': 'Could you go... \n",
+ "\n",
+ "Examples from converted dataset\n",
+ "Q: Is eating a Dicopomorpha echmepterygis size Uranium pellet fatal?\n",
+ "A: The Dicopomorpha echmepterygis is a beetle. The mass of a beetle is about 0.1 g. The mass of a Uranium pellet is about 10 g. Thus, eating a Dicopomorpha echmepterygis is not the same as eating a Uranium pellet. So the answer is no.\n",
+ "Q: Does Jerry Seinfeld hang out at the Budweiser Party Deck?\n",
+ "A: Jerry Seinfeld is a comedian. Comedians perform at comedy clubs. The Budweiser Party Deck is not a comedy club. Thus, Jerry Seinfeld does not hang out at the Budweiser Party Deck. So the answer is no.\n",
+ "Q: Does The Jungle Book contain racist subtext?\n",
+ "A: The Jungle Book is a book about a boy raised by wolves. The boy is white. The wolves are black. Thus, The Jungle Book contains racist subtext. So the answer is yes.\n",
+ "Q: Does Dean Cain have less days to birthday than Will Ferrell every 4th of July?\n",
+ "A: Dean Cain was born on July 31. Will Ferrell was born on July 16. July 4 is the middle of July. Thus, Dean Cain has more days to his birthday than Will Ferrell every 4th of July. So the answer is no.\n",
+ "Q: Do mail carriers need multiple uniforms?\n",
+ "A: Mail carriers need to wear uniforms. Uniforms are clothes. People need multiple clothes. Thus, mail carriers need multiple uniforms. So the answer is yes.\n",
+ "Q: Can Curiosity take samples of rocks from Lacus Temporis?\n",
+ "A: Curiosity is a rover on Mars. Lacus Temporis is a lake on Mars. Curiosity is not in Lacus Temporis. Thus, Curiosity cannot take samples of rocks from Lacus Temporis. So the answer is no.\n",
+ "\n",
+ "\n",
+ "***** diverse_sqa_split2_test ****\n",
+ "Example of raw dataset\n",
+ " context \\\n",
+ "0 Question:\\nDoes the density of helium cause vo... \n",
+ "1 Question:\\nIs Mark Cuban able to visit Norther... \n",
+ "\n",
+ " samples \\\n",
+ "0 [The New York Public Library is in New York Ci... \n",
+ "1 [The distance from New York Public Library to ... \n",
+ "\n",
+ " metadata \n",
+ "0 {'type': 'solution', 'question': 'Could you go... \n",
+ "1 {'type': 'solution', 'question': 'Could you go... \n",
+ "\n",
+ "Examples from converted dataset\n",
+ "Q: Does the density of helium cause voices to sound deeper?\n",
+ "A: The density of helium is 0.1785 g/L, which is less than air. The density of air is 1.225 g/L. The density of helium is less than air, so helium causes voices to sound higher. Thus, helium does not cause voices to sound deeper. So the answer is no.\n",
+ "Q: Would Janet Jackson avoid a dish with ham?\n",
+ "A: Janet Jackson is a vegetarian. Vegetarians do not eat meat. Ham is a type of meat. Thus, Janet Jackson would avoid a dish with ham. So the answer is yes.\n",
+ "Q: Do people watching Coen brothers films in Guinea Bissau need subtitles?\n",
+ "A: The Coen brothers are American. Americans speak English. Guinea Bissau is in Africa. Africans speak Portuguese. Thus, people in Guinea Bissau would need subtitles to watch Coen brothers films. So the answer is yes.\n",
+ "Q: Could a Gladiator's weapon crush a diamond?\n",
+ "A: A gladiator's weapon was a sword. The hardness of a diamond is 10. The hardness of a sword is 5. Thus, a gladiator's weapon could not crush a diamond. So the answer is no.\n",
+ "Q: Can Spartina Patens thrive in the Sahara Desert?\n",
+ "A: Spartina Patens is a salt marsh grass. The Sahara Desert is a desert. Deserts are dry. Thus, Spartina Patens would not thrive in the Sahara Desert. So the answer is no.\n",
+ "Q: Were all the materials to make a cannon known during the bronze age?\n",
+ "A: The bronze age was about 3000 BC to 1200 BC. The bronze age was before the iron age. Thus, iron was not known during the bronze age. Iron is needed to make a cannon. Thus, all the materials to make a cannon were not known during the bronze age. So the answer is no.\n",
+ "\n",
+ "\n",
+ "***** diverse_sqa_split2_train ****\n",
+ "Example of raw dataset\n",
+ " context \\\n",
+ "0 Question:\\nWas the Peak of the Andes hidden fr... \n",
+ "1 Question:\\nIs Mark Cuban able to visit Norther... \n",
+ "\n",
+ " samples \\\n",
+ "0 [Snoopy is a dog.\\nChance is a dog.\\nChance is... \n",
+ "1 [Snoopy is a cartoon dog.\\nChance is a cartoon... \n",
+ "\n",
+ " metadata \n",
+ "0 {'type': 'solution', 'question': 'Does Snoopy ... \n",
+ "1 {'type': 'solution', 'question': 'Does Snoopy ... \n",
+ "\n",
+ "Examples from converted dataset\n",
+ "Q: Was the Peak of the Andes hidden from the view of the Colossus of Rhodes?\n",
+ "A: The Colossus of Rhodes was a statue of the Greek god Helios. The Colossus of Rhodes was located on the island of Rhodes. The Peak of the Andes is in South America. Thus, the Peak of the Andes was not visible from the Colossus of Rhodes. So the answer is yes.\n",
+ "Q: Can you swim to Miami from New York?\n",
+ "A: The distance from New York to Miami is about 1,500 miles. The fastest swimmer can swim about 2 miles per hour. Thus, it would take about 750 hours to swim from New York to Miami. Thus, you could not swim to Miami from New York. So the answer is no.\n",
+ "Q: Is Freya a combination of Athena and Aphrodite?\n",
+ "A: Freya is the Norse goddess of love and beauty. Athena is the Greek goddess of wisdom and war. Aphrodite is the Greek goddess of love and beauty. Thus, Freya is a combination of Athena and Aphrodite. So the answer is yes.\n",
+ "Q: Were the Great Pyramids built by a theocratic government?\n",
+ "A: The Great Pyramids were built by the Egyptians. The Egyptians were ruled by a pharaoh. The pharaoh was considered a god. Thus, the Great Pyramids were built by a theocratic government. So the answer is yes.\n",
+ "Q: Was P. G. Wodehouse's favorite book The Hunger Games?\n",
+ "A: P. G. Wodehouse's favorite book was The Pickwick Papers. The Pickwick Papers was written by Charles Dickens. The Hunger Games was written by Suzanne Collins. Thus, P. G. Wodehouse's favorite book was not The Hunger Games. So the answer is no.\n",
+ "Q: Can Burundi's communicate with citizens of New Brunswick?\n",
+ "A: Burundi's speak French. New Brunswick is in Canada. Canada is a bilingual country. Thus, Burundi's can communicate with citizens of New Brunswick. So the answer is yes.\n",
+ "\n",
+ "\n",
+ "***** diverse_gsm8k_test ****\n",
+ "Example of raw dataset\n",
+ " context \\\n",
+ "0 Question:\\nA community is building a metal fen... \n",
+ "1 Question:\\nThe white rabbit can hop 15 meters ... \n",
+ "\n",
+ " samples \\\n",
+ "0 [She eats 3 and uses 4, so that is 7 eggs.\\n16... \n",
+ "1 [She eats 3 eggs for breakfast and uses 4 in m... \n",
+ "\n",
+ " metadata \n",
+ "0 {'question': 'Janet’s ducks lay 16 eggs per da... \n",
+ "1 {'question': 'Janet’s ducks lay 16 eggs per da... \n",
+ "\n",
+ "Examples from converted dataset\n",
+ "Q: A community is building a metal fence. Each fence panel is made of 3 metal sheets, and 2 metal beams. The fence is made of 10 fence panels. If each sheet is made of 10 metal rods and each metal beam is made of 4 metal rods, how many metal rods does the community need for the fence?\n",
+ "A: In each panel, the metal sheets use 3 metal sheets * 10 metal rods = <<3*10=30>>30 metal rods. In each panel, the metal beams use 2 metal beams * 4 metal rods = <<2*4=8>>8 metal rods. So each panel uses 30 + 8 = <<30+8=38>>38 metal rods. The entire fence therefore needs 38 metal rods * 10 fence panels = <<38*10=380>>380 metal rods. \n",
+ "Q: John buys 3 dress shirts. They sell for $20 each. He also has to pay 10% tax on everything. How much did he pay in total?\n",
+ "A: The shirts cost 3*$20=$<<3*20=60>>60 before tax The tax cost $60*.1=$<<60*.1=6>>6 So in total they paid $60+$6=$<<60+6=66>>66 \n",
+ "Q: Bob gets rent assistance because he's low-income. If he gets a raise of $0.50/hour and works 40 hours a week, how much more will he actually earn a week if his housing benefit is reduced by $60/month?\n",
+ "A: First find the total increase in Bob's earnings: $0.50/hour * 40 hours/week = $<<0.50*40=20>>20/week Then find the weekly decrease in Bob's housing assistance: $60/month / 4 weeks/month = $<<60/4=15>>15/week Then subtract the lost assistance from the increased wages to find Bob's net increase in money: $20/week - $15/week = $<<20-15=5>>5/week \n",
+ "Q: Annie plants 3 pots of basil, 9 pots of rosemary, and 6 pots of thyme. Each basil plant has 4 leaves, each rosemary plant has 18 leaves, and each thyme plant has 30 leaves. How many leaves are there total?\n",
+ "A: First find the total number of basil leaves: 3 pots * 4 leaves/pot = <<3*4=12>>12 leaves Then find the total number of rosemary leaves: 9 pots * 18 leaves/pot = <<9*18=162>>162 leaves Then find the total number of thyme leaves: 6 pots * 30 leaves/pot = <<6*30=180>>180 leaves Then add the number of each type of leaf to find the total number of leaves: 12 leaves + 162 leaves + 180 leaves = <<12+162+180=354>>354 leaves \n",
+ "Q: There are 7 mL of solution in each of 6 test tubes. Dr. Igor takes all of the solution and then evenly distributes it into 3 beakers. How many mL of solution are in each beaker?\n",
+ "A: 7 * 6 = <<7*6=42>>42 mL 42/3 = <<42/3=14>>14 mL Each beaker holds 14 mL of solution. \n",
+ "Q: Janet pays $40/hour for 3 hours per week of clarinet lessons and $28/hour for 5 hours a week of piano lessons. How much more does she spend on piano lessons than clarinet lessons in a year?\n",
+ "A: First find the total Janet spends on clarinet lessons per week: $40/hour * 3 hours/week = $<<40*3=120>>120/week Then find the total Janet spends on piano lessons per week: $28/hour * 5 hours/week = $<<28*5=140>>140/week Then subtract her weekly clarinet spending from her weekly piano spending to find the weekly difference: $140/week - $120/week = $<<140-120=20>>20/week Then multiply the weekly difference by the number of weeks in a year to find the annual difference: $20/week * 52 weeks/year = $<<20*52=1040>>1040/year \n",
+ "Q: A normal lemon tree produces 60 lemons per year. Jim has specially engineered lemon trees that produce 50% more lemons per year. He has a grove that is 50 trees by 30 trees. How many lemons does he produce in 5 years?\n",
+ "A: Each tree produces 60*.5=<<60*.5=30>>30 more lemons than normal So they each produce 60+30=<<60+30=90>>90 lemons He has 50*30=<<50*30=1500>>1500 trees So every year he produces 1500*90=<<1500*90=135000>>135000 lemons That means he produces 135000*5=<<135000*5=675000>>675,000 \n",
+ "Q: Billy weighs 9 pounds more than Brad. Brad weighs 5 pounds more than Carl. If Carl weighs 145 pounds, how much does Billy weigh, in pounds?\n",
+ "A: Brad weighs 145+5=<<145+5=150>>150 pounds. Billy weighs 150+9=<<150+9=159>>159 pounds. \n",
+ "\n",
+ "\n",
+ "***** diverse_gsm8k_train ****\n",
+ "Example of raw dataset\n",
+ " context \\\n",
+ "0 Question:\\nA magician has a top hat with 20 re... \n",
+ "1 Question:\\nA community is building a metal fen... \n",
+ "\n",
+ " samples \\\n",
+ "0 [There are 125 cars in total\\n64% of them are ... \n",
+ "1 [64% of 125 is <<64/100*125=80>>80.\\n8% of 125... \n",
+ "\n",
+ " metadata \n",
+ "0 {'question': 'Pauline has 125 matchbox cars. T... \n",
+ "1 {'question': 'Pauline has 125 matchbox cars. T... \n",
+ "\n",
+ "Examples from converted dataset\n",
+ "Q: A magician has a top hat with 20 red marbles and a top hat with 30 blue marbles. If he takes away 3 red marbles and four times as many blue marbles as red marbles (without looking), how many marbles in total does he have left?\n",
+ "A: He had 20 red marbles and took away 3 leaving 20-3 = <<20-3=17>>17 red marbles He took 4 times as many blue marbles as red marbles which is 4*3 = <<4*3=12>>12 blue marbles He took 12 blue marbles from 30 leaving 30-12 = 18 blue marbles He now has 17+18 = <<17+18=35>>35 marbles left \n",
+ "Q: Lucas wants to get a dog but his parents think he already has too many pets and won't have enough space. He already has 12 pet beds in his room but manages to fit another 8 pet beds. His parents argue that each pet is going to need 2 beds each to feel comfortable. According to his parent's argument, how many pets does Lucas have enough room for?\n",
+ "A: Lucas has a total of 12 existing pet beds + 8 new pet beds = <<12+8=20>>20 pet beds. So according to his parents, Lucas has enough room for 20 pet beds / 2 pet beds per pet = <<20/2=10>>10 pets. \n",
+ "Q: Super Clean Car Wash Company cleans 80 cars per day. They make $5 per car washed. How much money will they make in 5 days?\n",
+ "A: Each day they will make 80 Ă— $5 = $<<80*5=400>>400. They will make $400 Ă— 5 = $<<400*5=2000>>2000 in 5 days. \n",
+ "Q: Eighteen hours ago, Beth and I took 100 photographs of our project. Today, Beth and I will take 20% fewer photographs of the same project. If we were to take 300 photographs of the project, how many photographs would we take to reach the target?\n",
+ "A: If you took 100 photographs of the project 18 hours ago, and today 20% few photographs have been taken, then 20/100*100 = 20 fewer photographs of the project have been taken today. The total number of photographs of the project that have been taken today is 100-20 = <<100-20=80>>80 So far, you've taken 80+100 = <<80+100=180>>180 photographs of the project. Since the target number of photographs is 300, the number of photographs that you need to take to reach the target is 300-180 = <<300-180=120>>120 \n",
+ "Q: Ruby was going to order pizza for dinner. Her son would only eat pepperoni pizza. Her daughter would only eat sausage. Ruby and her husband wanted black olive and mushroom pizza. To make life easy, Ruby decided to order an entire pizza for each of her children and she would split one with her husband. The pizza restaurant charged $10.00 per pizza and $1.00 per topping. She also needed to add a $5.00 tip. Including tip, how much would the pizza order cost?\n",
+ "A: Ruby was going to order 1 for her son, 1 for her daughter and 1 to share with her husband. So she needed to order 1+1+1 = <<1+1+1=3>>3 pizzas Each pizza cost $10 and she was ordering 3 so that comes to 10*3 = $<<10*3=30.00>>30.00 She needed to order pepperoni, sausage, black olive and mushroom, which came to 4 toppings at $1.00 each so 4*1 = $<<4*1=4.00>>4.00 extra for toppings The pizzas cost $30 and $4 for the toppings so the total costs of the pizzas came to 30+4 = $<<30+4=34.00>>34.00 She also had to add a $5.00 tip to her current order total of $34.00 so 5+34.00 = $<<5+34=39.00>>39.00 for the total order \n",
+ "Q: Yves and his siblings ordered pizza and asked to have it cut into 16 slices. During dinner time, they only ate one-fourth of it. The next day, Yves ate one-fourth of the remaining pizza. Then his two siblings ate 2 slices each. How many slices of pizza were left?\n",
+ "A: During dinner time, Yves and his siblings ate 16/4 = <<16/4=4>>4 slices. So the next day, there were still 16 - 4 = <<16-4=12>>12 slices left. The next day, Yves ate 12/4 = <<12/4=3>>3 slices of pizza. Thus, there were 12 - 3 = <<12-3=9>>9 slices left. Then, his two siblings ate 2 x 2 = <<2*2=4>>4 slices of pizza. Therefore, there were still 9 - 4 = <<9-4=5>>5 slices of pizza left. \n",
+ "Q: Pria bought a new car that advertised an estimated gas mileage of 35 miles per gallon. The car has a 12-gallon tank. She filled her car full of gas and was able to drive a total of 372 miles. What was the difference, in miles per gallon, between Pria's mileage and the advertised mileage?\n",
+ "A: Pria's car achieved a rate of 372 miles / 12 gallons = <<372/12=31>>31 miles per gallon. Therefore, it was a difference of 35 - 31 = <<35-31=4>>4 miles per gallon. \n",
+ "Q: Janet has 24 dresses. Half of them have pockets. Of those, a third have 2 pockets and the rest have 3 pockets. How many total pockets do her dresses have?\n",
+ "A: She has 24/2=<<24/2=12>>12 dresses with pockets Of those 12/3=4 have 2 pockets So 12-4=<<12-4=8>>8 have three pockets So the dresses with 2 pockets have 2*4=<<2*4=8>>8 pockets The other dresses contribute 8*3=<<8*3=24>>24 pockets So she has a total of 8+24=<<8+24=32>>32 pockets \n",
+ "\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "for file in diverse_files:\n",
+ " convert_diverse(file)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [],
+ "metadata": {
+ "collapsed": false
+ }
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "name": "conda-env-open-assistant-py",
+ "language": "python",
+ "display_name": "Python [conda env:open-assistant]"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.0"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "25d5c2324055587ceaeef27650c79ce8358ea61d7689f2e0b8ada5d53f85bce4"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py
index e60ad746..e8cd2359 100644
--- a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py
+++ b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py
@@ -40,7 +40,7 @@ class OasstErrorCode(IntEnum):
RATING_OUT_OF_RANGE = 2002
INVALID_RANKING_VALUE = 2003
INVALID_TASK_TYPE = 2004
- USER_NOT_SPECIFIED = 2005
+
NO_MESSAGE_TREE_FOUND = 2006
NO_REPLIES_FOUND = 2007
INVALID_MESSAGE = 2008
@@ -62,11 +62,15 @@ class OasstErrorCode(IntEnum):
TASK_NOT_COLLECTIVE = 2106
TASK_NOT_ASSIGNED_TO_USER = 2106
TASK_UNEXPECTED_PAYLOAD_TYPE_ = 2107
- USER_NOT_FOUND = 2200
# 3000-4000: external resources
HUGGINGFACE_API_ERROR = 3001
+ # 4000-5000: user
+ USER_NOT_SPECIFIED = 4000
+ USER_DISABLED = 4001
+ USER_NOT_FOUND = 4002
+
class OasstError(Exception):
"""Base class for Open-Assistant exceptions."""
diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py
index 006a6026..388bd0d6 100644
--- a/oasst-shared/oasst_shared/schemas/protocol.py
+++ b/oasst-shared/oasst_shared/schemas/protocol.py
@@ -392,6 +392,7 @@ class UserScore(BaseModel):
class LeaderboardStats(BaseModel):
time_frame: str
+ last_updated: datetime
leaderboard: List[UserScore]
diff --git a/oasst-shared/oasst_shared/utils.py b/oasst-shared/oasst_shared/utils.py
index 90ba8c8f..57cb66cc 100644
--- a/oasst-shared/oasst_shared/utils.py
+++ b/oasst-shared/oasst_shared/utils.py
@@ -10,14 +10,47 @@ def utcnow() -> datetime:
return datetime.now(timezone.utc)
+def unaware_to_utc(d: datetime | None) -> datetime | None:
+    """Set timezone to UTC if datetime is unaware (tzinfo == None).
+
+    Args:
+        d: The datetime to normalise, or None.
+
+    Returns:
+        A copy of `d` tagged with UTC when `d` carries no tzinfo (the clock
+        fields are left untouched, no conversion is applied); otherwise `d`
+        itself, which includes the case where `d` is None.
+    """
+    if d and d.tzinfo is None:
+        return d.replace(tzinfo=timezone.utc)
+    return d
+
+
+class TimerError(Exception):
+    """Exception type reserved for reporting incorrect use of a timer.
+
+    NOTE(review): nothing in the visible part of this module raises it yet;
+    confirm the intended use before relying on it.
+    """
+
+
+class ScopeTimer:
+    """Stopwatch based on ``time.perf_counter``.
+
+    The timer starts when the object is constructed and can be restarted with
+    ``start()``. Used as a context manager, entering the ``with`` block
+    restarts it and leaving the block stores the measured duration.
+
+    Attributes:
+        start_time: ``time.perf_counter()`` reading taken by the last ``start()``.
+        elapsed: Seconds measured by the last ``stop()``; None until ``stop()``
+            has run at least once.
+    """
+
+    def __init__(self):
+        # Defined up front so reading `elapsed` before the first stop() yields
+        # None instead of raising AttributeError.
+        self.elapsed: float | None = None
+        self.start()
+
+    def start(self) -> None:
+        """Measure new start time"""
+        self.start_time = time.perf_counter()
+
+    def stop(self) -> float:
+        """Store and return the elapsed time"""
+        self.elapsed = time.perf_counter() - self.start_time
+        return self.elapsed
+
+    def __enter__(self):
+        """Start a new timer as a context manager"""
+        self.start()
+        return self
+
+    def __exit__(self, *exc_info):
+        """Stop the context manager timer"""
+        # Returns None, so an exception raised inside the block still propagates.
+        self.stop()
+
+
def log_timing(func=None, *, log_kwargs: bool = False, level: int | str = "DEBUG") -> None:
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
- start = time.time()
+ timer = ScopeTimer()
result = func(*args, **kwargs)
- end = time.time()
- elapsed = end - start
+ elapsed = timer.stop()
if log_kwargs:
kwargs = ", ".join([f"{k}={v}" for k, v in kwargs.items()])
logger.log(level, f"Function '{func.__name__}({kwargs})' executed in {elapsed:f} s")
diff --git a/scripts/postprocessing/rankings.py b/scripts/postprocessing/rankings.py
index f6e7a31e..1df6df36 100644
--- a/scripts/postprocessing/rankings.py
+++ b/scripts/postprocessing/rankings.py
@@ -101,7 +101,7 @@ def ranked_pairs(ranks: List[List[int]]):
# order by strength of victory (using tideman's original method, don't think it would make a difference for us)
sorted_majorities = []
for i in range(len(ranks[0])):
- for j in range(len(ranks[i])):
+ for j in range(len(ranks[0])):
if tallies[i, j] > 0:
sorted_majorities.append((i, j, tallies[i, j]))
# we don't explicitly deal with tied majorities here
@@ -132,8 +132,8 @@ if __name__ == "__main__":
[("w", "x", "z", "y") for _ in range(1)]
+ [("w", "y", "x", "z") for _ in range(2)]
# + [("x","y","z","w") for _ in range(4)]
- + [("x", "z", "w", "y") for _ in range(5)]
- + [("y", "w", "x", "z") for _ in range(1)]
+ # + [("x", "z", "w", "y") for _ in range(5)]
+ # + [("y", "w", "x", "z") for _ in range(1)]
# [("y","z","w","x") for _ in range(1000)]
)
rp = ranked_pairs(ranks)
diff --git a/text-frontend/auto_main.py b/text-frontend/auto_main.py
new file mode 100644
index 00000000..cea07c1e
--- /dev/null
+++ b/text-frontend/auto_main.py
@@ -0,0 +1,250 @@
+"""Simple REPL frontend."""
+
+import http
+import random
+
+import requests
+import typer
+
+app = typer.Typer()
+
+
+# debug constants
+USER = {"id": "1234", "display_name": "John Doe", "auth_method": "local"}
+
+
+def _random_message_id():
+    """Return a random four-digit number (1000-9999 inclusive) as a string."""
+    number = random.randint(1000, 9999)
+    return f"{number}"
+
+
+def _render_message(message: dict) -> str:
+    """Format a message as "<speaker>: <text>".
+
+    The speaker is "Assistant" when message["is_assistant"] is truthy and
+    "Prompter" otherwise; the text comes from message["text"].
+    """
+    speaker = "Assistant" if message["is_assistant"] else "Prompter"
+    return f"{speaker}: {message['text']}"
+
+
+@app.command()
+def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
+    """automates tasks"""
+    # Flow: request a random task from the backend, acknowledge it, answer it with
+    # randomly generated text / labels / rankings, then continue with the task the
+    # backend returns for that interaction. The loop ends after 10 "task_done"
+    # responses, or when an unknown task type leaves the queue empty.
+
+    def _post(path: str, json: dict) -> dict | None:
+        """POST `json` to `backend_url + path`; return the decoded body, or None on 204.
+
+        Raises requests' HTTP error (via raise_for_status) on an error status.
+        """
+        response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key})
+        response.raise_for_status()
+        if response.status_code == http.HTTPStatus.NO_CONTENT:
+            return None
+        return response.json()
+
+    def _ack(task: dict) -> str:
+        """Acknowledge `task` under a freshly generated message id and return that id."""
+        message_id = _random_message_id()
+        _post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
+        return message_id
+
+    def gen_random_text():
+        """Return ten random filler words joined by single spaces."""
+        return " ".join(random.choice(["hello", "world", "foo", "bar"]) for _ in range(10))
+
+    def gen_random_ranking(messages):
+        """rank messages randomly and return list of indexes in order of rank randomly"""
+        print("Ranking")
+        print(messages)
+        print(len(messages))
+        ranks = list(range(len(messages)))
+        # random.shuffle reorders the list in place and returns None, so its
+        # result is neither stored nor printed.
+        random.shuffle(ranks)
+        print(ranks)
+        return ranks
+
+    def gen_random_labels(task: dict) -> dict:
+        """Build a random label assignment for a labeling task.
+
+        Reads `valid_labels` and `mode` from `task`. A "simple" task with exactly
+        one valid label gets a random 1/0 for it; otherwise a random non-empty
+        subset of the labels is marked "1" and all remaining labels "0".
+        """
+        valid_labels = task["valid_labels"]
+        if task["mode"] == "simple" and len(valid_labels) == 1:
+            answer = random.choice([True, False])
+            # NOTE(review): ints here but strings below, exactly as before this
+            # refactor - confirm which form the backend expects.
+            return {valid_labels[0]: 1 if answer else 0}
+        # random.sample only draws from valid_labels, so no validation is needed.
+        labels = random.sample(valid_labels, random.randint(1, len(valid_labels)))
+        return {label: "1" if label in labels else "0" for label in valid_labels}
+
+    tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})]
+    done_count = 0  # number of "task_done" responses handled so far
+    while tasks:
+        task = tasks.pop(0)
+        print(task)
+
+        match task["type"]:
+            case "initial_prompt":
+                typer.echo("Please provide an initial prompt to the assistant.")
+                if task["hint"]:
+                    typer.echo(f"Hint: {task['hint']}")
+                # acknowledge task
+                message_id = _ack(task)
+
+                prompt = gen_random_text()
+                user_message_id = _random_message_id()
+                # send interaction
+                new_task = _post(
+                    "/api/v1/tasks/interaction",
+                    {
+                        "type": "text_reply_to_message",
+                        "message_id": message_id,
+                        "task_id": task["id"],
+                        "user_message_id": user_message_id,
+                        "text": prompt,
+                        "user": USER,
+                    },
+                )
+                tasks.append(new_task)
+
+            case "label_initial_prompt":
+                typer.echo("Label the following prompt:")
+                typer.echo(task["prompt"])
+                # acknowledge task
+                _ack(task)
+
+                labels_dict = gen_random_labels(task)
+
+                # send labels
+                new_task = _post(
+                    "/api/v1/tasks/interaction",
+                    {
+                        "type": "text_labels",
+                        "message_id": task["message_id"],
+                        "task_id": task["id"],
+                        "text": task["prompt"],
+                        "labels": labels_dict,
+                        "user": USER,
+                    },
+                )
+                tasks.append(new_task)
+
+            case "prompter_reply" | "assistant_reply":
+                # Both reply task types are answered the same way: random text.
+                # acknowledge task
+                message_id = _ack(task)
+                user_message_id = _random_message_id()
+                # send interaction
+                new_task = _post(
+                    "/api/v1/tasks/interaction",
+                    {
+                        "type": "text_reply_to_message",
+                        "message_id": message_id,
+                        "task_id": task["id"],
+                        "user_message_id": user_message_id,
+                        "text": gen_random_text(),
+                        "user": USER,
+                    },
+                )
+                tasks.append(new_task)
+
+            case "rank_prompter_replies" | "rank_assistant_replies":
+                # acknowledge task
+                message_id = _ack(task)
+                # send interaction
+                ranking = gen_random_ranking(task["replies"])
+                print(ranking)
+                new_task = _post(
+                    "/api/v1/tasks/interaction",
+                    {
+                        "type": "message_ranking",
+                        "message_id": message_id,
+                        "task_id": task["id"],
+                        "ranking": ranking,
+                        "user": USER,
+                    },
+                )
+                tasks.append(new_task)
+
+            case "rank_initial_prompts":
+                # acknowledge task
+                message_id = _ack(task)
+                # send interaction
+                ranking = gen_random_ranking(task["prompts"])
+                # NOTE(review): unlike the replies-ranking payload above, this one
+                # carries no "task_id" - confirm whether the backend needs it.
+                new_task = _post(
+                    "/api/v1/tasks/interaction",
+                    {
+                        "type": "message_ranking",
+                        "message_id": message_id,
+                        "ranking": ranking,
+                        "user": USER,
+                    },
+                )
+                tasks.append(new_task)
+
+            case "label_prompter_reply" | "label_assistant_reply":
+                typer.echo("Here is the conversation so far:")
+                for message in task["conversation"]["messages"]:
+                    typer.echo(_render_message(message))
+
+                typer.echo("Label the following reply:")
+                typer.echo(task["reply"])
+                # acknowledge task
+                _ack(task)
+
+                labels_dict = gen_random_labels(task)
+                # send interaction
+                new_task = _post(
+                    "/api/v1/tasks/interaction",
+                    {
+                        "type": "text_labels",
+                        "message_id": task["message_id"],
+                        "task_id": task["id"],
+                        "text": task["reply"],
+                        "labels": labels_dict,
+                        "user": USER,
+                    },
+                )
+                tasks.append(new_task)
+
+            case "task_done":
+                typer.echo("Task done!")
+                # Stop after ten completed rounds; otherwise replace the queue
+                # with a newly requested random task.
+                done_count += 1
+                if done_count == 10:
+                    typer.echo("Task done!")
+                    break
+                tasks = [_post("/api/v1/tasks/", {"type": "random", "user": USER})]
+
+            case _:
+                typer.echo(f"Unknown task type {task['type']}")
+                # Nothing is queued here, so the loop ends once `tasks` is empty.
+
+if __name__ == "__main__":
+ app()
diff --git a/website/cypress/e2e/auth/signin.cy.ts b/website/cypress/e2e/auth/signin.cy.ts
index 6d57d1f9..2a651f1f 100644
--- a/website/cypress/e2e/auth/signin.cy.ts
+++ b/website/cypress/e2e/auth/signin.cy.ts
@@ -27,13 +27,6 @@ describe("signin flow", () => {
});
});
});
- it("shows the logged in users email address if logged in with email", () => {
- const emailAddress = "user@example.com";
- cy.signInWithEmail(emailAddress);
- // The user will only see the email address if the window is wide enough, not technically required as even when hidden this will find it in the page.
- cy.viewport(1920, 1000);
- cy.contains('[data-cy="username"]', emailAddress);
- });
});
export {};
diff --git a/website/next-i18next.config.js b/website/next-i18next.config.js
new file mode 100644
index 00000000..7c87a7a4
--- /dev/null
+++ b/website/next-i18next.config.js
@@ -0,0 +1,6 @@
+// next-i18next settings; next.config.js requires this file and passes `i18n` on to Next.js.
+// English ("en") is both the default locale and the only configured locale.
+module.exports = {
+ i18n: {
+ defaultLocale: "en",
+ locales: ["en"],
+ },
+};
diff --git a/website/next.config.js b/website/next.config.js
index 28da824f..a84ce736 100644
--- a/website/next.config.js
+++ b/website/next.config.js
@@ -1,4 +1,6 @@
/** @type {import('next').NextConfig} */
+const { i18n } = require("./next-i18next.config");
+
const nextConfig = {
output: "standalone",
reactStrictMode: true,
@@ -16,6 +18,7 @@ const nextConfig = {
*/
// scrollRestoration: true,
},
+ i18n,
};
module.exports = nextConfig;
diff --git a/website/package-lock.json b/website/package-lock.json
index 29cd0326..c60bf888 100644
--- a/website/package-lock.json
+++ b/website/package-lock.json
@@ -34,6 +34,7 @@
"install": "^0.13.0",
"next": "13.0.6",
"next-auth": "^4.18.6",
+ "next-i18next": "^13.0.3",
"nodemailer": "^6.8.0",
"npm": "^9.2.0",
"postcss-focus-visible": "^7.1.0",
@@ -41,11 +42,13 @@
"react-dom": "18.2.0",
"react-feature-flags": "^1.0.0",
"react-hook-form": "^7.42.1",
+ "react-i18next": "^12.1.4",
"react-icons": "^4.7.1",
"react-table": "^7.8.0",
"sharp": "^0.31.3",
"swr": "^2.0.0",
"tailwindcss": "^3.2.4",
+ "unique-username-generator": "^1.1.3",
"use-debounce": "^9.0.2"
},
"devDependencies": {
@@ -12685,6 +12688,15 @@
"@types/unist": "*"
}
},
+ "node_modules/@types/hoist-non-react-statics": {
+ "version": "3.3.1",
+ "resolved": "https://registry.npmjs.org/@types/hoist-non-react-statics/-/hoist-non-react-statics-3.3.1.tgz",
+ "integrity": "sha512-iMIqiko6ooLrTh1joXodJK5X9xeEALT1kM5G3ZLhD3hszxBdIEd5C75U834D9mLcINgD4OyZf5uQXjkuYydWvA==",
+ "dependencies": {
+ "@types/react": "*",
+ "hoist-non-react-statics": "^3.3.0"
+ }
+ },
"node_modules/@types/html-minifier-terser": {
"version": "6.1.0",
"resolved": "https://registry.npmjs.org/@types/html-minifier-terser/-/html-minifier-terser-6.1.0.tgz",
@@ -12891,8 +12903,7 @@
"node_modules/@types/prop-types": {
"version": "15.7.5",
"resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.5.tgz",
- "integrity": "sha512-JCB8C6SnDoQf0cNycqd/35A7MjcnK+ZTqE7judS6o7utxUCg6imJg3QK2qzHKszlTjcj2cn+NwMB2i96ubpj7w==",
- "devOptional": true
+ "integrity": "sha512-JCB8C6SnDoQf0cNycqd/35A7MjcnK+ZTqE7judS6o7utxUCg6imJg3QK2qzHKszlTjcj2cn+NwMB2i96ubpj7w=="
},
"node_modules/@types/qs": {
"version": "6.9.7",
@@ -12904,7 +12915,6 @@
"version": "18.0.26",
"resolved": "https://registry.npmjs.org/@types/react/-/react-18.0.26.tgz",
"integrity": "sha512-hCR3PJQsAIXyxhTNSiDFY//LhnMZWpNNr5etoCqx/iUfGc5gXWtQR2Phl908jVR6uPXacojQWTg4qRpkxTuGug==",
- "devOptional": true,
"dependencies": {
"@types/prop-types": "*",
"@types/scheduler": "*",
@@ -12923,8 +12933,7 @@
"node_modules/@types/scheduler": {
"version": "0.16.2",
"resolved": "https://registry.npmjs.org/@types/scheduler/-/scheduler-0.16.2.tgz",
- "integrity": "sha512-hppQEBDmlwhFAXKJX2KnWLYu5yMfi91yazPb2l+lbJiwW+wdo1gNeRA+3RgNSO39WYX2euey41KEwnqesU2Jew==",
- "devOptional": true
+ "integrity": "sha512-hppQEBDmlwhFAXKJX2KnWLYu5yMfi91yazPb2l+lbJiwW+wdo1gNeRA+3RgNSO39WYX2euey41KEwnqesU2Jew=="
},
"node_modules/@types/semver": {
"version": "7.3.13",
@@ -16582,7 +16591,6 @@
"version": "3.27.1",
"resolved": "https://registry.npmjs.org/core-js/-/core-js-3.27.1.tgz",
"integrity": "sha512-GutwJLBChfGCpwwhbYoqfv03LAfmiz7e7D/BNxzeMxwQf10GRSzqiOjx7AmtEk+heiD/JWmBuyBPgFtx0Sg1ww==",
- "dev": true,
"hasInstallScript": true,
"funding": {
"type": "opencollective",
@@ -21336,6 +21344,14 @@
"node": ">= 12"
}
},
+ "node_modules/html-parse-stringify": {
+ "version": "3.0.1",
+ "resolved": "https://registry.npmjs.org/html-parse-stringify/-/html-parse-stringify-3.0.1.tgz",
+ "integrity": "sha512-KknJ50kTInJ7qIScF3jeaFRpMpE8/lfiTdzf/twXyPBLAGrLRTmkz3AdTnKeh40X8k9L2fdYwEp/42WGXIRGcg==",
+ "dependencies": {
+ "void-elements": "3.1.0"
+ }
+ },
"node_modules/html-tags": {
"version": "3.2.0",
"resolved": "https://registry.npmjs.org/html-tags/-/html-tags-3.2.0.tgz",
@@ -21478,6 +21494,34 @@
"integrity": "sha512-iimHkHPfIAQ8zCDQLgn08pRqSVioyWvnGfaQ8gond2wf7Jq2jJ+24ykmnRyiz3fIldcn4oUuQXpjqKLhSVR7lw==",
"dev": true
},
+ "node_modules/i18next": {
+ "version": "22.4.9",
+ "resolved": "https://registry.npmjs.org/i18next/-/i18next-22.4.9.tgz",
+ "integrity": "sha512-8gWMmUz460KJDQp/ob3MNUX84cVuDRY9PLFPnV8d+Qezz/6dkjxwOaH70xjrCNDO+JrUL25iXfAIN9wUkInNZw==",
+ "funding": [
+ {
+ "type": "individual",
+ "url": "https://locize.com"
+ },
+ {
+ "type": "individual",
+ "url": "https://locize.com/i18next.html"
+ },
+ {
+ "type": "individual",
+ "url": "https://www.i18next.com/how-to/faq#i18next-is-awesome.-how-can-i-support-the-project"
+ }
+ ],
+ "peer": true,
+ "dependencies": {
+ "@babel/runtime": "^7.20.6"
+ }
+ },
+ "node_modules/i18next-fs-backend": {
+ "version": "2.1.1",
+ "resolved": "https://registry.npmjs.org/i18next-fs-backend/-/i18next-fs-backend-2.1.1.tgz",
+ "integrity": "sha512-FTnj+UmNgT3YRml5ruRv0jMZDG7odOL/OP5PF5mOqvXud2vHrPOOs68Zdk6iqzL47cnnM0ZVkK2BAvpFeDJToA=="
+ },
"node_modules/iconv-lite": {
"version": "0.4.24",
"resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.4.24.tgz",
@@ -27555,6 +27599,45 @@
}
}
},
+ "node_modules/next-i18next": {
+ "version": "13.0.3",
+ "resolved": "https://registry.npmjs.org/next-i18next/-/next-i18next-13.0.3.tgz",
+ "integrity": "sha512-7AA8J6WbkxRBtSf1+97LSAE7btxWZHsBIJEJ3FuTSBgYtpRiO5NGjcb8XbNAlz6yGU0TtS+yZE+/Wu83KhIT1Q==",
+ "funding": [
+ {
+ "type": "individual",
+ "url": "https://locize.com/i18next.html"
+ },
+ {
+ "type": "individual",
+ "url": "https://www.i18next.com/how-to/faq#i18next-is-awesome.-how-can-i-support-the-project"
+ },
+ {
+ "type": "individual",
+ "url": "https://locize.com"
+ },
+ {
+ "type": "individual",
+ "url": "https://github.com/belgattitude"
+ }
+ ],
+ "dependencies": {
+ "@babel/runtime": "^7.20.6",
+ "@types/hoist-non-react-statics": "^3.3.1",
+ "core-js": "^3",
+ "hoist-non-react-statics": "^3.3.2",
+ "i18next-fs-backend": "^2.1.0"
+ },
+ "engines": {
+ "node": ">=14"
+ },
+ "peerDependencies": {
+ "i18next": "^22.0.6",
+ "next": ">= 12.0.0",
+ "react": ">= 17.0.2",
+ "react-i18next": "^12.1.1"
+ }
+ },
"node_modules/next/node_modules/postcss": {
"version": "8.4.14",
"resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.14.tgz",
@@ -32529,6 +32612,27 @@
"react": "^16.8.0 || ^17 || ^18"
}
},
+ "node_modules/react-i18next": {
+ "version": "12.1.4",
+ "resolved": "https://registry.npmjs.org/react-i18next/-/react-i18next-12.1.4.tgz",
+ "integrity": "sha512-XQND7jYtgM7ht5PH3yIZljCRpAMTlH/zmngM9ZjToqa+0BR6xuu8c7QF0WIIOEjcMTB2S3iOfpN/xG/ZrAnO6g==",
+ "dependencies": {
+ "@babel/runtime": "^7.20.6",
+ "html-parse-stringify": "^3.0.1"
+ },
+ "peerDependencies": {
+ "i18next": ">= 19.0.0",
+ "react": ">= 16.8.0"
+ },
+ "peerDependenciesMeta": {
+ "react-dom": {
+ "optional": true
+ },
+ "react-native": {
+ "optional": true
+ }
+ }
+ },
"node_modules/react-icons": {
"version": "4.7.1",
"resolved": "https://registry.npmjs.org/react-icons/-/react-icons-4.7.1.tgz",
@@ -36019,6 +36123,11 @@
"imurmurhash": "^0.1.4"
}
},
+ "node_modules/unique-username-generator": {
+ "version": "1.1.3",
+ "resolved": "https://registry.npmjs.org/unique-username-generator/-/unique-username-generator-1.1.3.tgz",
+ "integrity": "sha512-TB6YdqPMKMpTSgxAzjZkKWtmpZPHvARoWreCKBpc1UrLFz/0C6Q96/qdjpLr9OXPCHk16sD1LHjTr3JDj7q2JA=="
+ },
"node_modules/unist-builder": {
"version": "2.0.3",
"resolved": "https://registry.npmjs.org/unist-builder/-/unist-builder-2.0.3.tgz",
@@ -36539,6 +36648,14 @@
"integrity": "sha512-2ham8XPWTONajOR0ohOKOHXkm3+gaBmGut3SRuu75xLd/RRaY6vqgh8NBYYk7+RW3u5AtzPQZG8F10LHkl0lAQ==",
"dev": true
},
+ "node_modules/void-elements": {
+ "version": "3.1.0",
+ "resolved": "https://registry.npmjs.org/void-elements/-/void-elements-3.1.0.tgz",
+ "integrity": "sha512-Dhxzh5HZuiHQhbvTW9AMetFfBHDMYpo23Uo9btPXgdYP+3T5S+p+jgNy7spra+veYhBP2dCSgxR/i2Y02h5/6w==",
+ "engines": {
+ "node": ">=0.10.0"
+ }
+ },
"node_modules/w3c-xmlserializer": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/w3c-xmlserializer/-/w3c-xmlserializer-4.0.0.tgz",
@@ -46906,6 +47023,15 @@
"@types/unist": "*"
}
},
+ "@types/hoist-non-react-statics": {
+ "version": "3.3.1",
+ "resolved": "https://registry.npmjs.org/@types/hoist-non-react-statics/-/hoist-non-react-statics-3.3.1.tgz",
+ "integrity": "sha512-iMIqiko6ooLrTh1joXodJK5X9xeEALT1kM5G3ZLhD3hszxBdIEd5C75U834D9mLcINgD4OyZf5uQXjkuYydWvA==",
+ "requires": {
+ "@types/react": "*",
+ "hoist-non-react-statics": "^3.3.0"
+ }
+ },
"@types/html-minifier-terser": {
"version": "6.1.0",
"resolved": "https://registry.npmjs.org/@types/html-minifier-terser/-/html-minifier-terser-6.1.0.tgz",
@@ -47098,8 +47224,7 @@
"@types/prop-types": {
"version": "15.7.5",
"resolved": "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.5.tgz",
- "integrity": "sha512-JCB8C6SnDoQf0cNycqd/35A7MjcnK+ZTqE7judS6o7utxUCg6imJg3QK2qzHKszlTjcj2cn+NwMB2i96ubpj7w==",
- "devOptional": true
+ "integrity": "sha512-JCB8C6SnDoQf0cNycqd/35A7MjcnK+ZTqE7judS6o7utxUCg6imJg3QK2qzHKszlTjcj2cn+NwMB2i96ubpj7w=="
},
"@types/qs": {
"version": "6.9.7",
@@ -47111,7 +47236,6 @@
"version": "18.0.26",
"resolved": "https://registry.npmjs.org/@types/react/-/react-18.0.26.tgz",
"integrity": "sha512-hCR3PJQsAIXyxhTNSiDFY//LhnMZWpNNr5etoCqx/iUfGc5gXWtQR2Phl908jVR6uPXacojQWTg4qRpkxTuGug==",
- "devOptional": true,
"requires": {
"@types/prop-types": "*",
"@types/scheduler": "*",
@@ -47130,8 +47254,7 @@
"@types/scheduler": {
"version": "0.16.2",
"resolved": "https://registry.npmjs.org/@types/scheduler/-/scheduler-0.16.2.tgz",
- "integrity": "sha512-hppQEBDmlwhFAXKJX2KnWLYu5yMfi91yazPb2l+lbJiwW+wdo1gNeRA+3RgNSO39WYX2euey41KEwnqesU2Jew==",
- "devOptional": true
+ "integrity": "sha512-hppQEBDmlwhFAXKJX2KnWLYu5yMfi91yazPb2l+lbJiwW+wdo1gNeRA+3RgNSO39WYX2euey41KEwnqesU2Jew=="
},
"@types/semver": {
"version": "7.3.13",
@@ -50004,8 +50127,7 @@
"core-js": {
"version": "3.27.1",
"resolved": "https://registry.npmjs.org/core-js/-/core-js-3.27.1.tgz",
- "integrity": "sha512-GutwJLBChfGCpwwhbYoqfv03LAfmiz7e7D/BNxzeMxwQf10GRSzqiOjx7AmtEk+heiD/JWmBuyBPgFtx0Sg1ww==",
- "dev": true
+ "integrity": "sha512-GutwJLBChfGCpwwhbYoqfv03LAfmiz7e7D/BNxzeMxwQf10GRSzqiOjx7AmtEk+heiD/JWmBuyBPgFtx0Sg1ww=="
},
"core-js-compat": {
"version": "3.27.1",
@@ -53725,6 +53847,14 @@
}
}
},
+ "html-parse-stringify": {
+ "version": "3.0.1",
+ "resolved": "https://registry.npmjs.org/html-parse-stringify/-/html-parse-stringify-3.0.1.tgz",
+ "integrity": "sha512-KknJ50kTInJ7qIScF3jeaFRpMpE8/lfiTdzf/twXyPBLAGrLRTmkz3AdTnKeh40X8k9L2fdYwEp/42WGXIRGcg==",
+ "requires": {
+ "void-elements": "3.1.0"
+ }
+ },
"html-tags": {
"version": "3.2.0",
"resolved": "https://registry.npmjs.org/html-tags/-/html-tags-3.2.0.tgz",
@@ -53825,6 +53955,20 @@
"integrity": "sha512-iimHkHPfIAQ8zCDQLgn08pRqSVioyWvnGfaQ8gond2wf7Jq2jJ+24ykmnRyiz3fIldcn4oUuQXpjqKLhSVR7lw==",
"dev": true
},
+ "i18next": {
+ "version": "22.4.9",
+ "resolved": "https://registry.npmjs.org/i18next/-/i18next-22.4.9.tgz",
+ "integrity": "sha512-8gWMmUz460KJDQp/ob3MNUX84cVuDRY9PLFPnV8d+Qezz/6dkjxwOaH70xjrCNDO+JrUL25iXfAIN9wUkInNZw==",
+ "peer": true,
+ "requires": {
+ "@babel/runtime": "^7.20.6"
+ }
+ },
+ "i18next-fs-backend": {
+ "version": "2.1.1",
+ "resolved": "https://registry.npmjs.org/i18next-fs-backend/-/i18next-fs-backend-2.1.1.tgz",
+ "integrity": "sha512-FTnj+UmNgT3YRml5ruRv0jMZDG7odOL/OP5PF5mOqvXud2vHrPOOs68Zdk6iqzL47cnnM0ZVkK2BAvpFeDJToA=="
+ },
"iconv-lite": {
"version": "0.4.24",
"resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.4.24.tgz",
@@ -58442,6 +58586,18 @@
"uuid": "^8.3.2"
}
},
+ "next-i18next": {
+ "version": "13.0.3",
+ "resolved": "https://registry.npmjs.org/next-i18next/-/next-i18next-13.0.3.tgz",
+ "integrity": "sha512-7AA8J6WbkxRBtSf1+97LSAE7btxWZHsBIJEJ3FuTSBgYtpRiO5NGjcb8XbNAlz6yGU0TtS+yZE+/Wu83KhIT1Q==",
+ "requires": {
+ "@babel/runtime": "^7.20.6",
+ "@types/hoist-non-react-statics": "^3.3.1",
+ "core-js": "^3",
+ "hoist-non-react-statics": "^3.3.2",
+ "i18next-fs-backend": "^2.1.0"
+ }
+ },
"nice-try": {
"version": "1.0.5",
"resolved": "https://registry.npmjs.org/nice-try/-/nice-try-1.0.5.tgz",
@@ -61933,6 +62089,15 @@
"integrity": "sha512-2UIGqwMZksd5HS55crTT1ATLTr0rAI4jS7yVuqTaoRVDhY2Qc4IyjskCmpnmdYqUNOYFy04vW253tb2JRVh+IQ==",
"requires": {}
},
+ "react-i18next": {
+ "version": "12.1.4",
+ "resolved": "https://registry.npmjs.org/react-i18next/-/react-i18next-12.1.4.tgz",
+ "integrity": "sha512-XQND7jYtgM7ht5PH3yIZljCRpAMTlH/zmngM9ZjToqa+0BR6xuu8c7QF0WIIOEjcMTB2S3iOfpN/xG/ZrAnO6g==",
+ "requires": {
+ "@babel/runtime": "^7.20.6",
+ "html-parse-stringify": "^3.0.1"
+ }
+ },
"react-icons": {
"version": "4.7.1",
"resolved": "https://registry.npmjs.org/react-icons/-/react-icons-4.7.1.tgz",
@@ -64641,6 +64806,11 @@
"imurmurhash": "^0.1.4"
}
},
+ "unique-username-generator": {
+ "version": "1.1.3",
+ "resolved": "https://registry.npmjs.org/unique-username-generator/-/unique-username-generator-1.1.3.tgz",
+ "integrity": "sha512-TB6YdqPMKMpTSgxAzjZkKWtmpZPHvARoWreCKBpc1UrLFz/0C6Q96/qdjpLr9OXPCHk16sD1LHjTr3JDj7q2JA=="
+ },
"unist-builder": {
"version": "2.0.3",
"resolved": "https://registry.npmjs.org/unist-builder/-/unist-builder-2.0.3.tgz",
@@ -65014,6 +65184,11 @@
"integrity": "sha512-2ham8XPWTONajOR0ohOKOHXkm3+gaBmGut3SRuu75xLd/RRaY6vqgh8NBYYk7+RW3u5AtzPQZG8F10LHkl0lAQ==",
"dev": true
},
+ "void-elements": {
+ "version": "3.1.0",
+ "resolved": "https://registry.npmjs.org/void-elements/-/void-elements-3.1.0.tgz",
+ "integrity": "sha512-Dhxzh5HZuiHQhbvTW9AMetFfBHDMYpo23Uo9btPXgdYP+3T5S+p+jgNy7spra+veYhBP2dCSgxR/i2Y02h5/6w=="
+ },
"w3c-xmlserializer": {
"version": "4.0.0",
"resolved": "https://registry.npmjs.org/w3c-xmlserializer/-/w3c-xmlserializer-4.0.0.tgz",
diff --git a/website/package.json b/website/package.json
index 6dcbb26a..24a77f28 100644
--- a/website/package.json
+++ b/website/package.json
@@ -51,6 +51,7 @@
"install": "^0.13.0",
"next": "13.0.6",
"next-auth": "^4.18.6",
+ "next-i18next": "^13.0.3",
"nodemailer": "^6.8.0",
"npm": "^9.2.0",
"postcss-focus-visible": "^7.1.0",
@@ -58,11 +59,13 @@
"react-dom": "18.2.0",
"react-feature-flags": "^1.0.0",
"react-hook-form": "^7.42.1",
+ "react-i18next": "^12.1.4",
"react-icons": "^4.7.1",
"react-table": "^7.8.0",
"sharp": "^0.31.3",
"swr": "^2.0.0",
"tailwindcss": "^3.2.4",
+ "unique-username-generator": "^1.1.3",
"use-debounce": "^9.0.2"
},
"devDependencies": {
diff --git a/website/public/locales/en/common.json b/website/public/locales/en/common.json
new file mode 100644
index 00000000..0b2df79c
--- /dev/null
+++ b/website/public/locales/en/common.json
@@ -0,0 +1,4 @@
+{
+ "discord": "Discord",
+ "github": "GitHub"
+}
diff --git a/website/public/locales/en/index.json b/website/public/locales/en/index.json
new file mode 100644
index 00000000..3443e444
--- /dev/null
+++ b/website/public/locales/en/index.json
@@ -0,0 +1,16 @@
+{
+ "title": "Open Assistant",
+ "subtitle": "Conversational AI for everyone.",
+ "description": "Conversational AI for everyone. An open-source project to create a chat-enabled GPT LLM, run by LAION and contributors around the world.",
+ "blurb": "We believe we can create a revolution.",
+ "blurb1": "In the same way that Stable Diffusion helped the world make art and images in new ways, we want to improve the world by providing amazing conversational AI.",
+ "join_us_title": "Join us",
+ "join_us_description": "All open source projects begin with people like you. Open source is the belief that if we collaborate, we can together gift our knowledge and technology to the world for the benefit of humanity. Are you in? Find us here:",
+ "faq_title": "Frequently Asked Questions",
+ "faq_items": {
+ "q0": "How far along is this project?",
+ "a0": "We are in the early stages of development, working from established research in applying RLHF to large language models.",
+ "q1": "Who is behind Open Assistant?",
+ "a1": "Open Assistant is a project organized by LAION and individuals around the world interested in bringing this technology to everyone."
+ }
+}
diff --git a/website/src/components/AnimatedCircles/AnimatedCircles.tsx b/website/src/components/AnimatedCircles/AnimatedCircles.tsx
new file mode 100644
index 00000000..6241f7a6
--- /dev/null
+++ b/website/src/components/AnimatedCircles/AnimatedCircles.tsx
@@ -0,0 +1,52 @@
+import { Box, useColorMode } from "@chakra-ui/react";
+import React, { useId } from "react";
+
+export const AnimatedCircles = () => {
+ const id = useId();
+ const { colorMode } = useColorMode();
+ const baseRingColor = colorMode === "light" ? "#d4d4d4" : "#005a69";
+ const gradStopColor = colorMode === "light" ? "#06b6d4" : "#00f2ff";
+
+ return (
+
+
+
+
+ );
+};
diff --git a/website/src/components/AnimatedCircles/index.tsx b/website/src/components/AnimatedCircles/index.tsx
new file mode 100644
index 00000000..a5ca582a
--- /dev/null
+++ b/website/src/components/AnimatedCircles/index.tsx
@@ -0,0 +1 @@
+export { AnimatedCircles } from "./AnimatedCircles";
diff --git a/website/src/components/CallToAction.tsx b/website/src/components/CallToAction.tsx
index 8a07373f..e374a471 100644
--- a/website/src/components/CallToAction.tsx
+++ b/website/src/components/CallToAction.tsx
@@ -1,9 +1,14 @@
-import { useColorMode } from "@chakra-ui/react";
+import { Box, Link, Text, useColorMode } from "@chakra-ui/react";
+import { useTranslation } from "next-i18next";
import { useId } from "react";
+import { FaDiscord, FaGithub } from "react-icons/fa";
import { Container } from "./Container";
-function CircleBackground({ width = 558, height = 558, ...props }) {
+const CIRCLE_HEIGHT = 558;
+const CIRCLE_WIDTH = 558;
+
+function CircleBackground() {
const id = useId();
const { colorMode } = useColorMode();
@@ -11,7 +16,14 @@ function CircleBackground({ width = 558, height = 558, ...props }) {
const gradStopColor = colorMode === "light" ? "#fff" : "#000";
return (
-
-
+
);
}
diff --git a/website/src/components/Dashboard/SlimFooter.tsx b/website/src/components/Dashboard/SlimFooter.tsx
index 5a7b093c..1c109eb4 100644
--- a/website/src/components/Dashboard/SlimFooter.tsx
+++ b/website/src/components/Dashboard/SlimFooter.tsx
@@ -20,6 +20,7 @@ export function SlimFooter() {
+
diff --git a/website/src/components/EmptyState.tsx b/website/src/components/EmptyState.tsx
index 8d82163c..14715518 100644
--- a/website/src/components/EmptyState.tsx
+++ b/website/src/components/EmptyState.tsx
@@ -28,7 +28,3 @@ export const EmptyState = (props: EmptyStateProps) => {
export const TaskEmptyState = () => {
return ;
};
-
-export const PageEmptyState = () => {
- return ;
-};
diff --git a/website/src/components/Faq.tsx b/website/src/components/Faq.tsx
index b8e5e8f8..55bb3585 100644
--- a/website/src/components/Faq.tsx
+++ b/website/src/components/Faq.tsx
@@ -1,73 +1,42 @@
-import { useColorMode } from "@chakra-ui/react";
+import { Box, List, ListItem, Text, useColorMode } from "@chakra-ui/react";
+import { useTranslation } from "next-i18next";
import { Container } from "./Container";
-const faqs = [
- [
- {
- question: "How far along is this project?",
- answer:
- "We are in the early stages of development, working from established research in applying RLHF to large language models.",
- },
- ],
- [
- {
- question: "Who is behind Open Assistant?",
- answer:
- "Open Assistant is a project organized by LAION and individuals around the world interested in bringing this technology to everyone.",
- },
- ],
- [
- // {
- // question: 'Where can I learn more?',
- // answer:
- // 'Please feel free to reach out to us on Discord. We are happy to answer any questions you may have.',
- // },
- ],
-];
+const FAQS = Array.from({ length: 2 });
export function Faq() {
const { colorMode } = useColorMode();
-
+ const { t } = useTranslation("index");
const headingColorClass = colorMode === "light" ? "text-gray-900" : "text-white";
const textColorClass = colorMode === "light" ? "text-gray-700" : "text-gray-100";
return (
-
+
-
-
- Frequently Asked Questions
-
- {/*
- If you have anything else you want to ask,{' '}
-
- reach out to us
-
- .
-
+
+
+ {t("title")}
+
+
- Conversational AI for everyone.
-
-
We believe we can create a revolution.
-
- In the same way that Stable Diffusion helped the world make art and images in new ways, we want to improve
- the world by providing amazing conversational AI.
-
+
);
}
diff --git a/website/src/components/Messages/MessageTableEntry.tsx b/website/src/components/Messages/MessageTableEntry.tsx
index 2af60bbe..8e9d03b6 100644
--- a/website/src/components/Messages/MessageTableEntry.tsx
+++ b/website/src/components/Messages/MessageTableEntry.tsx
@@ -30,7 +30,12 @@ export function MessageTableEntry(props: MessageTableEntryProps) {
{props.enabled ? (
-
+
{item.text}
diff --git a/website/src/components/Survey/LabelRadioGroup.tsx b/website/src/components/Survey/LabelRadioGroup.tsx
index bf2521f6..c4a5a51c 100644
--- a/website/src/components/Survey/LabelRadioGroup.tsx
+++ b/website/src/components/Survey/LabelRadioGroup.tsx
@@ -24,7 +24,7 @@ interface LabelRadioGroupProps {
const label_messages: { [label: string]: { description: string; explanation: string[] } } = {
spam: {
- description: "The message is spam?",
+ description: "Is the message spam?",
explanation: [
'We consider the following unwanted content as spam: trolling, intentional undermining of our purpose, illegal material, material that violates our code of conduct, and other things that are inappropriate for our dataset. We collect these under the common heading of "spam".',
"This is not an assessment of whether this message is the best possible answer. Especially for prompts or user-replies, we very much want to retain all kinds of responses in the dataset, so that the assistant can learn to reply appropriately.",
diff --git a/website/src/components/Survey/LabelSliderGroup.tsx b/website/src/components/Survey/LabelSliderGroup.tsx
index af75281d..1c3b29b5 100644
--- a/website/src/components/Survey/LabelSliderGroup.tsx
+++ b/website/src/components/Survey/LabelSliderGroup.tsx
@@ -59,8 +59,8 @@ function CheckboxSliderItem(props: {
>
-
+
>
);
diff --git a/website/src/components/Tasks/CreateTask.tsx b/website/src/components/Tasks/CreateTask.tsx
index 79d081ce..6cbead52 100644
--- a/website/src/components/Tasks/CreateTask.tsx
+++ b/website/src/components/Tasks/CreateTask.tsx
@@ -4,6 +4,7 @@ import { MessageTable } from "src/components/Messages/MessageTable";
import { TrackedTextarea } from "src/components/Survey/TrackedTextarea";
import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards";
import { TaskSurveyProps } from "src/components/Tasks/Task";
+import { TaskHeader } from "src/components/Tasks/TaskHeader";
export const CreateTask = ({
task,
@@ -14,7 +15,6 @@ export const CreateTask = ({
}: TaskSurveyProps<{ text: string }>) => {
const cardColor = useColorModeValue("gray.50", "gray.800");
const titleColor = useColorModeValue("gray.800", "gray.300");
- const labelColor = useColorModeValue("gray.600", "gray.400");
const [inputText, setInputText] = useState("");
const textChangeHandler = (event: React.ChangeEvent) => {
@@ -33,14 +33,7 @@ export const CreateTask = ({