diff --git a/.github/workflows/deploy-to-node.yaml b/.github/workflows/deploy-to-node.yaml index a88c09aa..f107d0af 100644 --- a/.github/workflows/deploy-to-node.yaml +++ b/.github/workflows/deploy-to-node.yaml @@ -33,6 +33,9 @@ jobs: WEB_EMAIL_SERVER_PORT: ${{ secrets.DEV_WEB_EMAIL_SERVER_PORT }} WEB_EMAIL_SERVER_USER: ${{ secrets.DEV_WEB_EMAIL_SERVER_USER }} WEB_NEXTAUTH_SECRET: ${{ secrets.NEXTAUTH_SECRET }} + S3_BUCKET_NAME: ${{ secrets.S3_BUCKET_NAME }} + AWS_ACCESS_KEY: ${{ secrets.AWS_ACCESS_KEY }} + AWS_SECRET_KEY: ${{ secrets.AWS_SECRET_KEY }} steps: - name: Checkout uses: actions/checkout@v2 diff --git a/ansible/deploy-to-node.yaml b/ansible/deploy-to-node.yaml index defff549..94746437 100644 --- a/ansible/deploy-to-node.yaml +++ b/ansible/deploy-to-node.yaml @@ -78,6 +78,29 @@ - name: backend - name: web + - name: Copy pgbackrest.conf to managed node + ansible.builtin.copy: + src: ./pgbackrest.conf + dest: "./{{ stack_name }}/pgbackrest.conf" + mode: 0644 + + - name: Create pgbackrest container + community.docker.docker_container: + name: "oasst-{{ stack_name }}-pgbackrest" + image: woblerr/pgbackrest:2.43 + state: "{{ 'stopped' if stack_name == 'production' else 'absent' }}" + network_mode: "oasst-{{ stack_name }}" + volumes: + - "./{{ stack_name }}/pgbackrest.conf:/etc/pgbackrest/pgbackrest.conf" + - "oasst-{{ stack_name }}-postgres-backend:/var/lib/postgresql/data" + env: + PGBACKREST_REPO1_S3_BUCKET: + "{{ lookup('ansible.builtin.env', 'S3_BUCKET_NAME') }}" + PGBACKREST_REPO1_S3_KEY: + "{{ lookup('ansible.builtin.env', 'AWS_ACCESS_KEY') }}" + PGBACKREST_REPO1_S3_KEY_SECRET: + "{{ lookup('ansible.builtin.env', 'AWS_SECRET_KEY') }}" + - name: Run the oasst oasst-backend community.docker.docker_container: name: "oasst-{{ stack_name }}-backend" @@ -136,7 +159,9 @@ FASTAPI_KEY: "{{ web_api_key }}" NEXTAUTH_SECRET: "{{ lookup('ansible.builtin.env', 'WEB_NEXTAUTH_SECRET') }}" - NEXTAUTH_URL: http://web.{{ stack_name }}.open-assistant.io/ + NEXTAUTH_URL: + "{{ 'https://open-assistant.io/' if stack_name == 'production' else + ('https://web.' + stack_name + '.open-assistant.io/') }}" ports: - "{{ website_port }}:3000" command: bash wait-for-postgres.sh node server.js diff --git a/ansible/pgbackrest.conf b/ansible/pgbackrest.conf new file mode 100644 index 00000000..036826d3 --- /dev/null +++ b/ansible/pgbackrest.conf @@ -0,0 +1,24 @@ +[oasst] +pg1-path=/var/lib/postgresql/data + +[global] +repo1-retention-full=3 # keep last 3 backups +repo1-type=s3 +repo1-path=/oasst-prod +repo1-s3-region=us-east-1 +repo1-s3-endpoint=s3.amazonaws.com +# repo1-s3-bucket=$S3_BUCKET_NAME +# repo1-s3-key=$AWS_ACCESS_KEY +# repo1-s3-key-secret=$AWS_SECRET_KEY + +# Force a checkpoint to start backup immediately. +start-fast=y +# Use delta restore. +delta=y + +# Enable ZSTD compression. +compress-type=zst +compress-level=6 + +log-level-console=info +log-level-file=debug diff --git a/backend/main.py b/backend/main.py index 9c900f8b..ce316f41 100644 --- a/backend/main.py +++ b/backend/main.py @@ -20,6 +20,7 @@ from oasst_backend.models import message_tree_state from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository from oasst_backend.tree_manager import TreeManager from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame +from oasst_backend.utils.database_utils import CommitMode, managed_tx_function from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from pydantic import BaseModel @@ -120,7 +121,8 @@ if settings.RATE_LIMIT: if settings.DEBUG_USE_SEED_DATA: @app.on_event("startup") - def seed_data(): + @managed_tx_function(auto_commit=CommitMode.COMMIT) + def create_seed_data(session: Session): class DummyMessage(BaseModel): task_message_id: str user_message_id: str @@ -134,73 +136,73 @@ if settings.DEBUG_USE_SEED_DATA: try: logger.info("Seed data check began") - with Session(engine) as db: - api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db) - dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local") - ur = UserRepository(db=db, api_client=api_client) - tr = TaskRepository(db=db, api_client=api_client, client_user=dummy_user, user_repository=ur) - pr = PromptRepository( - db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr - ) - tm = TreeManager(db, pr) + api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=session) + dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local") - with open(settings.DEBUG_USE_SEED_DATA_PATH) as f: - dummy_messages_raw = json.load(f) + ur = UserRepository(db=session, api_client=api_client) + tr = TaskRepository(db=session, api_client=api_client, client_user=dummy_user, user_repository=ur) + pr = PromptRepository( + db=session, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr + ) + tm = TreeManager(session, pr) - dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw] + with open(settings.DEBUG_USE_SEED_DATA_PATH) as f: + dummy_messages_raw = json.load(f) - for msg in dummy_messages: - task = tr.fetch_task_by_frontend_message_id(msg.task_message_id) - if task and not task.ack: - logger.warning("Deleting unacknowledged seed data task") - db.delete(task) - task = None - if not task: - if msg.parent_message_id is None: - task = tr.store_task( - protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None - ) - else: - parent_message = pr.fetch_message_by_frontend_message_id( - msg.parent_message_id, fail_if_missing=True - ) - conversation_messages = pr.fetch_message_conversation(parent_message) - conversation = prepare_conversation(conversation_messages) - if msg.role == "assistant": - task = tr.store_task( - protocol_schema.AssistantReplyTask(conversation=conversation), - message_tree_id=parent_message.message_tree_id, - parent_message_id=parent_message.id, - ) - else: - task = tr.store_task( - protocol_schema.PrompterReplyTask(conversation=conversation), - message_tree_id=parent_message.message_tree_id, - parent_message_id=parent_message.id, - ) - tr.bind_frontend_message_id(task.id, msg.task_message_id) - message = pr.store_text_reply( - msg.text, - msg.task_message_id, - msg.user_message_id, - review_count=5, - review_result=True, - check_tree_state=False, - ) - if message.parent_id is None: - tm._insert_default_state( - root_message_id=message.id, state=msg.tree_state or message_tree_state.State.GROWING - ) - db.commit() + dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw] - logger.info( - f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}" + for msg in dummy_messages: + task = tr.fetch_task_by_frontend_message_id(msg.task_message_id) + if task and not task.ack: + logger.warning("Deleting unacknowledged seed data task") + session.delete(task) + task = None + if not task: + if msg.parent_message_id is None: + task = tr.store_task( + protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None ) else: - logger.debug(f"seed data task found: {task.id}") + parent_message = pr.fetch_message_by_frontend_message_id( + msg.parent_message_id, fail_if_missing=True + ) + conversation_messages = pr.fetch_message_conversation(parent_message) + conversation = prepare_conversation(conversation_messages) + if msg.role == "assistant": + task = tr.store_task( + protocol_schema.AssistantReplyTask(conversation=conversation), + message_tree_id=parent_message.message_tree_id, + parent_message_id=parent_message.id, + ) + else: + task = tr.store_task( + protocol_schema.PrompterReplyTask(conversation=conversation), + message_tree_id=parent_message.message_tree_id, + parent_message_id=parent_message.id, + ) + tr.bind_frontend_message_id(task.id, msg.task_message_id) + message = pr.store_text_reply( + msg.text, + msg.task_message_id, + msg.user_message_id, + review_count=5, + review_result=True, + check_tree_state=False, + ) + if message.parent_id is None: + tm._insert_default_state( + root_message_id=message.id, state=msg.tree_state or message_tree_state.State.GROWING + ) + session.flush() - logger.info("Seed data check completed") + logger.info( + f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}" + ) + else: + logger.debug(f"seed data task found: {task.id}") + + logger.info("Seed data check completed") except Exception: logger.exception("Seed data insertion failed") @@ -220,48 +222,44 @@ def ensure_tree_states(): @app.on_event("startup") @repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_DAY, wait_first=False) -def update_leader_board_day() -> None: +@managed_tx_function(auto_commit=CommitMode.COMMIT) +def update_leader_board_day(session: Session) -> None: try: - with Session(engine) as session: - usr = UserStatsRepository(session) - usr.update_stats(time_frame=UserStatsTimeFrame.day) - session.commit() + usr = UserStatsRepository(session) + usr.update_stats(time_frame=UserStatsTimeFrame.day) except Exception: logger.exception("Error during leaderboard update (daily)") @app.on_event("startup") @repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_WEEK, wait_first=False) -def update_leader_board_week() -> None: +@managed_tx_function(auto_commit=CommitMode.COMMIT) +def update_leader_board_week(session: Session) -> None: try: - with Session(engine) as session: - usr = UserStatsRepository(session) - usr.update_stats(time_frame=UserStatsTimeFrame.week) - session.commit() + usr = UserStatsRepository(session) + usr.update_stats(time_frame=UserStatsTimeFrame.week) except Exception: logger.exception("Error during user states update (weekly)") @app.on_event("startup") @repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_MONTH, wait_first=False) -def update_leader_board_month() -> None: +@managed_tx_function(auto_commit=CommitMode.COMMIT) +def update_leader_board_month(session: Session) -> None: try: - with Session(engine) as session: - usr = UserStatsRepository(session) - usr.update_stats(time_frame=UserStatsTimeFrame.month) - session.commit() + usr = UserStatsRepository(session) + usr.update_stats(time_frame=UserStatsTimeFrame.month) except Exception: logger.exception("Error during user states update (monthly)") @app.on_event("startup") @repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_TOTAL, wait_first=False) -def update_leader_board_total() -> None: +@managed_tx_function(auto_commit=CommitMode.COMMIT) +def update_leader_board_total(session: Session) -> None: try: - with Session(engine) as session: - usr = UserStatsRepository(session) - usr.update_stats(time_frame=UserStatsTimeFrame.total) - session.commit() + usr = UserStatsRepository(session) + usr.update_stats(time_frame=UserStatsTimeFrame.total) except Exception: logger.exception("Error during user states update (total)") diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 17df814f..c65500fb 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Optional from uuid import UUID from fastapi import APIRouter, Depends @@ -48,6 +48,27 @@ def request_task( return task +@router.post("/availability", response_model=dict[protocol_schema.TaskRequestType, int]) +def tasks_availability( + *, + user: Optional[protocol_schema.User] = None, + db: Session = Depends(deps.get_db), + api_key: APIKey = Depends(deps.get_api_key), +): + api_client = deps.api_auth(api_key, db) + + try: + pr = PromptRepository(db, api_client, client_user=user) + tm = TreeManager(db, pr) + return tm.determine_task_availability() + + except OasstError: + raise + except Exception: + logger.exception("Task availability query failed.") + raise OasstError("Task availability query failed.", OasstErrorCode.TASK_AVAILABILITY_QUERY_FAILED) + + @router.post("/{task_id}/ack", response_model=None, status_code=HTTP_204_NO_CONTENT) def tasks_acknowledge( *, diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index b0c6b7f5..99b10cb4 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -55,6 +55,8 @@ class TreeManagerConfiguration(BaseModel): mandatory_labels_prompter_reply: Optional[list[protocol_schema.TextLabel]] = [protocol_schema.TextLabel.spam] """Mandatory labels in text-labeling tasks for prompter replies.""" + rank_prompter_replies: bool = False + class Settings(BaseSettings): PROJECT_NAME: str = "open-assistant backend" @@ -67,6 +69,7 @@ class Settings(BaseSettings): POSTGRES_PASSWORD: str = "postgres" POSTGRES_DB: str = "postgres" DATABASE_URI: Optional[PostgresDsn] = None + DATABASE_MAX_TX_RETRY_COUNT: int = 3 RATE_LIMIT: bool = True REDIS_HOST: str = "localhost" @@ -79,6 +82,7 @@ class Settings(BaseSettings): DEBUG_ALLOW_SELF_LABELING: bool = False # allow users to label their own messages DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False DEBUG_SKIP_TOXICITY_CALCULATION: bool = False + DEBUG_DATABASE_ECHO: bool = False HUGGING_FACE_API_KEY: str = "" diff --git a/backend/oasst_backend/database.py b/backend/oasst_backend/database.py index b160da61..1d0e19b2 100644 --- a/backend/oasst_backend/database.py +++ b/backend/oasst_backend/database.py @@ -5,4 +5,4 @@ from sqlmodel import create_engine if settings.DATABASE_URI is None: raise OasstError("DATABASE_URI is not set", error_code=OasstErrorCode.DATABASE_URI_NOT_SET) -engine = create_engine(settings.DATABASE_URI) +engine = create_engine(settings.DATABASE_URI, echo=settings.DEBUG_DATABASE_ECHO, isolation_level="REPEATABLE READ") diff --git a/backend/oasst_backend/journal_writer.py b/backend/oasst_backend/journal_writer.py index 67892ded..b39b498d 100644 --- a/backend/oasst_backend/journal_writer.py +++ b/backend/oasst_backend/journal_writer.py @@ -4,6 +4,7 @@ from uuid import UUID from oasst_backend.models import ApiClient, Journal, Task, User from oasst_backend.models.payload_column_type import PayloadContainer, payload_type +from oasst_backend.utils.database_utils import CommitMode, managed_tx_method from oasst_shared.utils import utcnow from pydantic import BaseModel from sqlmodel import Session @@ -80,6 +81,7 @@ class JournalWriter: message_id=message_id, ) + @managed_tx_method(CommitMode.FLUSH) def log( self, *, @@ -115,7 +117,4 @@ class JournalWriter: ) self.db.add(entry) - if commit: - self.db.commit() - return entry diff --git a/backend/oasst_backend/models/payload_column_type.py b/backend/oasst_backend/models/payload_column_type.py index 01b642e2..132a7a78 100644 --- a/backend/oasst_backend/models/payload_column_type.py +++ b/backend/oasst_backend/models/payload_column_type.py @@ -48,6 +48,8 @@ def payload_column_type(pydantic_type): class PayloadJSONBType(TypeDecorator, Generic[T]): impl = pg.JSONB() + cache_ok = True + def __init__( self, json_encoder=json, diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 6483cdc2..7f51ea19 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -24,6 +24,7 @@ from oasst_backend.models import ( from oasst_backend.models.payload_column_type import PayloadContainer from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id from oasst_backend.user_repository import UserRepository +from oasst_backend.utils.database_utils import CommitMode, managed_tx_method from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.schemas.protocol import SystemStats @@ -67,6 +68,7 @@ class PromptRepository: ) return message + @managed_tx_method(CommitMode.FLUSH) def insert_message( self, *, @@ -104,8 +106,8 @@ class PromptRepository: review_result=review_result, ) self.db.add(message) - self.db.commit() - self.db.refresh(message) + + # self.db.refresh(message) return message def _validate_task( @@ -134,6 +136,7 @@ class PromptRepository: def fetch_tree_state(self, message_tree_id: UUID) -> MessageTreeState: return self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one() + @managed_tx_method(CommitMode.FLUSH) def store_text_reply( self, text: str, @@ -205,10 +208,10 @@ class PromptRepository: if not task.collective: task.done = True self.db.add(task) - self.db.commit() self.journal.log_text_reply(task=task, message_id=new_message_id, role=role, length=len(text)) return user_message + @managed_tx_method(CommitMode.FLUSH) def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction: message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True) @@ -238,6 +241,7 @@ class PromptRepository: logger.info(f"Ranking {rating.rating} stored for task {task.id}.") return reaction + @managed_tx_method(CommitMode.COMMIT) def store_ranking(self, ranking: protocol_schema.MessageRanking) -> Tuple[MessageReaction, Task]: # fetch task task = self.task_repository.fetch_task_by_frontend_message_id(ranking.message_id) @@ -310,6 +314,7 @@ class PromptRepository: return reaction, task + @managed_tx_method(CommitMode.FLUSH) def insert_toxicity(self, message_id: UUID, model: str, score: float, label: str) -> MessageToxicity: """Save the toxicity score of a new message in the database. Args: @@ -325,10 +330,9 @@ class PromptRepository: message_toxicity = MessageToxicity(message_id=message_id, model=model, score=score, label=label) self.db.add(message_toxicity) - self.db.commit() - self.db.refresh(message_toxicity) return message_toxicity + @managed_tx_method(CommitMode.FLUSH) def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding: """Insert the embedding of a new message in the database. @@ -346,10 +350,9 @@ class PromptRepository: message_embedding = MessageEmbedding(message_id=message_id, model=model, embedding=embedding) self.db.add(message_embedding) - self.db.commit() - self.db.refresh(message_embedding) return message_embedding + @managed_tx_method(CommitMode.FLUSH) def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction: if self.user_id is None: raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED) @@ -363,10 +366,9 @@ class PromptRepository: payload_type=type(payload).__name__, ) self.db.add(reaction) - self.db.commit() - self.db.refresh(reaction) return reaction + @managed_tx_method(CommitMode.FLUSH) def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> Tuple[TextLabels, Task, Message]: valid_labels: Optional[list[str]] = None @@ -436,8 +438,6 @@ class PromptRepository: self.db.add(message) self.db.add(model) - self.db.commit() - self.db.refresh(model) return model, task, message def fetch_random_message_tree(self, require_role: str = None, reviewed: bool = True) -> list[Message]: @@ -702,6 +702,7 @@ class PromptRepository: return messages.all() + @managed_tx_method(CommitMode.COMMIT) def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True): """ Marks deleted messages and all their descendants. @@ -730,8 +731,6 @@ class PromptRepository: parent_ids = self.db.execute(query).scalars().all() - self.db.commit() - def get_stats(self) -> SystemStats: """ Get data stats such as number of all messages in the system, diff --git a/backend/oasst_backend/task_repository.py b/backend/oasst_backend/task_repository.py index acf48182..eb100fe3 100644 --- a/backend/oasst_backend/task_repository.py +++ b/backend/oasst_backend/task_repository.py @@ -6,6 +6,7 @@ from loguru import logger from oasst_backend.models import ApiClient, Task from oasst_backend.models.payload_column_type import PayloadContainer from oasst_backend.user_repository import UserRepository +from oasst_backend.utils.database_utils import CommitMode, managed_tx_method from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session @@ -128,6 +129,7 @@ class TaskRepository: assert task_model.id == task.id return task_model + @managed_tx_method(CommitMode.COMMIT) def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str): validate_frontend_message_id(frontend_message_id) @@ -142,10 +144,9 @@ class TaskRepository: task.frontend_message_id = frontend_message_id task.ack = True - # ToDo: check race-condition, transaction self.db.add(task) - self.db.commit() + @managed_tx_method(CommitMode.COMMIT) def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False): """ Mark task as done. No further messages will be accepted for this task. @@ -166,8 +167,8 @@ class TaskRepository: task.done = True self.db.add(task) - self.db.commit() + @managed_tx_method(CommitMode.COMMIT) def acknowledge_task_failure(self, task_id): # find task task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first() @@ -181,8 +182,8 @@ class TaskRepository: task.ack = False # ToDo: check race-condition, transaction self.db.add(task) - self.db.commit() + @managed_tx_method(CommitMode.COMMIT) def insert_task( self, payload: db_payload.TaskPayload, @@ -204,8 +205,6 @@ class TaskRepository: ) logger.debug(f"inserting {task=}") self.db.add(task) - self.db.commit() - self.db.refresh(task) return task def fetch_task_by_frontend_message_id(self, message_id: str) -> Task: diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index a9a282e2..225b0146 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -9,13 +9,14 @@ import pydantic from loguru import logger from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list from oasst_backend.config import TreeManagerConfiguration, settings -from oasst_backend.models import Message, MessageReaction, MessageTreeState, TextLabels, message_tree_state +from oasst_backend.models import Message, MessageReaction, MessageTreeState, Task, TextLabels, message_tree_state from oasst_backend.prompt_repository import PromptRepository +from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from sqlalchemy.sql import text -from sqlmodel import Session, func +from sqlmodel import Session, func, not_ class TaskType(Enum): @@ -48,6 +49,7 @@ class ActiveTreeSizeRow(pydantic.BaseModel): class ExtendibleParentRow(pydantic.BaseModel): parent_id: UUID + parent_role: str depth: int message_tree_id: UUID active_children_count: int @@ -58,6 +60,7 @@ class ExtendibleParentRow(pydantic.BaseModel): class IncompleteRankingsRow(pydantic.BaseModel): parent_id: UUID + role: str children_count: int child_min_ranking_count: int @@ -69,21 +72,23 @@ class TreeManager: _all_text_labels = list(map(lambda x: x.value, protocol_schema.TextLabel)) def __init__( - self, db: Session, prompt_repository: PromptRepository, cfg: Optional[TreeManagerConfiguration] = None + self, + db: Session, + prompt_repository: PromptRepository, + cfg: Optional[TreeManagerConfiguration] = None, ): self.db = db self.cfg = cfg or settings.tree_manager self.pr = prompt_repository - def _task_selection( + def _random_task_selection( self, - desired_task_type: protocol_schema.TaskRequestType, num_ranking_tasks: int, num_replies_need_review: int, num_prompts_need_review: int, num_missing_prompts: int, num_missing_replies: int, - ) -> Tuple[TaskType, TaskRole]: + ) -> TaskType: """ Determines which task to hand out to human worker. The task type is drawn with relative weight (e.g. ranking has highest priority) @@ -91,75 +96,97 @@ class TreeManager: """ logger.debug( - f"TreeManager._task_selection({num_ranking_tasks=}, {num_replies_need_review=}, " + f"TreeManager._random_task_selection({num_ranking_tasks=}, {num_replies_need_review=}, " f"{num_prompts_need_review=}, {num_missing_prompts=}, {num_missing_replies=})" ) task_type = TaskType.NONE - task_role = TaskRole.ANY - if desired_task_type == protocol_schema.TaskRequestType.random: - task_weights = [0] * 5 + task_weights = [0] * 5 - if num_ranking_tasks > 0: - task_weights[TaskType.RANKING.value] = 10 + if num_ranking_tasks > 0: + task_weights[TaskType.RANKING.value] = 10 - if num_replies_need_review > 0: - task_weights[TaskType.LABEL_REPLY.value] = 5 + if num_replies_need_review > 0: + task_weights[TaskType.LABEL_REPLY.value] = 5 - if num_prompts_need_review > 0: - task_weights[TaskType.LABEL_PROMPT.value] = 5 + if num_prompts_need_review > 0: + task_weights[TaskType.LABEL_PROMPT.value] = 5 - if num_missing_replies > 0: - task_weights[TaskType.REPLY.value] = 2 + if num_missing_replies > 0: + task_weights[TaskType.REPLY.value] = 2 - if num_missing_prompts > 0: - task_weights[TaskType.PROMPT.value] = 1 + if num_missing_prompts > 0: + task_weights[TaskType.PROMPT.value] = 1 - task_weights = np.array(task_weights) - weight_sum = task_weights.sum() - if weight_sum < 1e-8: - task_type = TaskType.NONE - else: - task_weights = task_weights / weight_sum - task_type = TaskType(np.random.choice(a=len(task_weights), p=task_weights)) - else: - match desired_task_type: - case protocol_schema.TaskRequestType.initial_prompt: - if num_missing_prompts > 0: - task_type = TaskType.PROMPT - case protocol_schema.TaskRequestType.label_initial_prompt: - if num_prompts_need_review > 0: - task_type = TaskType.LABEL_PROMPT - case protocol_schema.TaskRequestType.assistant_reply | protocol_schema.TaskRequestType.prompter_reply: - if num_missing_replies > 0: - task_role = ( - TaskRole.ASSISTANT - if desired_task_type == protocol_schema.TaskRequestType.assistant_reply - else TaskRole.PROMPTER - ) - task_type = TaskType.REPLY - case protocol_schema.TaskRequestType.label_assistant_reply | protocol_schema.TaskRequestType.label_prompter_reply: - if num_replies_need_review > 0: - task_role = ( - TaskRole.ASSISTANT - if desired_task_type == protocol_schema.TaskRequestType.label_assistant_reply - else TaskRole.PROMPTER - ) - task_type = TaskType.LABEL_REPLY - case protocol_schema.TaskRequestType.rank_assistant_replies | protocol_schema.TaskRequestType.rank_prompter_replies: - if num_ranking_tasks > 0: - task_role = ( - TaskRole.ASSISTANT - if desired_task_type == protocol_schema.TaskRequestType.rank_assistant_replies - else TaskRole.PROMPTER - ) - task_type = TaskType.RANKING + task_weights = np.array(task_weights) + weight_sum = task_weights.sum() + if weight_sum > 1e-8: + task_weights = task_weights / weight_sum + task_type = TaskType(np.random.choice(a=len(task_weights), p=task_weights)) - logger.debug(f"Selected {task_type=}, {task_role=}") - return task_type, task_role + logger.debug(f"Selected {task_type=}") + return task_type + + def _determine_task_availability_internal( + self, + num_active_trees: int, + extensible_parents: list[ExtendibleParentRow], + prompts_need_review: list[Message], + replies_need_review: list[Message], + incomplete_rankings: list[IncompleteRankingsRow], + ) -> dict[protocol_schema.TaskRequestType, int]: + task_count_by_type: dict[protocol_schema.TaskRequestType, int] = {t: 0 for t in protocol_schema.TaskRequestType} + + num_missing_prompts = max(0, self.cfg.max_active_trees - num_active_trees) + task_count_by_type[protocol_schema.TaskRequestType.initial_prompt] = num_missing_prompts + + task_count_by_type[protocol_schema.TaskRequestType.prompter_reply] = len( + list(filter(lambda x: x.parent_role == "assistant", extensible_parents)) + ) + task_count_by_type[protocol_schema.TaskRequestType.assistant_reply] = len( + list(filter(lambda x: x.parent_role == "prompter", extensible_parents)) + ) + + task_count_by_type[protocol_schema.TaskRequestType.label_initial_prompt] = len(prompts_need_review) + task_count_by_type[protocol_schema.TaskRequestType.label_assistant_reply] = len( + list(filter(lambda m: m.role == "assistant", replies_need_review)) + ) + task_count_by_type[protocol_schema.TaskRequestType.prompter_reply] = len( + list(filter(lambda m: m.role == "prompter", replies_need_review)) + ) + + if self.cfg.rank_prompter_replies: + task_count_by_type[protocol_schema.TaskRequestType.rank_prompter_replies] = len( + list(filter(lambda r: r.role == "prompter", incomplete_rankings)) + ) + + task_count_by_type[protocol_schema.TaskRequestType.rank_assistant_replies] = len( + list(filter(lambda r: r.role == "assistant", incomplete_rankings)) + ) + + task_count_by_type[protocol_schema.TaskRequestType.random] = sum( + task_count_by_type[t] for t in protocol_schema.TaskRequestType if t in task_count_by_type + ) + + return task_count_by_type + + def determine_task_availability(self) -> dict[protocol_schema.TaskRequestType, int]: + num_active_trees = self.query_num_active_trees() + extensible_parents = self.query_extendible_parents() + prompts_need_review = self.query_prompts_need_review() + replies_need_review = self.query_replies_need_review() + incomplete_rankings = self.query_incomplete_rankings() + + return self._determine_task_availability_internal( + num_active_trees=num_active_trees, + extensible_parents=extensible_parents, + prompts_need_review=prompts_need_review, + replies_need_review=replies_need_review, + incomplete_rankings=incomplete_rankings, + ) def next_task( - self, desired_task_type: protocol_schema.TaskRequestType + self, desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random ) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]: logger.debug("TreeManager.next_task()") @@ -167,148 +194,195 @@ class TreeManager: num_active_trees = self.query_num_active_trees() prompts_need_review = self.query_prompts_need_review() replies_need_review = self.query_replies_need_review() + extensible_parents = self.query_extendible_parents() + incomplete_rankings = self.query_incomplete_rankings() + if not self.cfg.rank_prompter_replies: + incomplete_rankings = list(filter(lambda r: r.role == "assistant", incomplete_rankings)) + active_tree_sizes = self.query_extendible_trees() # determine type of task to generate num_missing_replies = sum(x.remaining_messages for x in active_tree_sizes) - task_type, task_role = self._task_selection( - desired_task_type, - num_ranking_tasks=len(incomplete_rankings), - num_replies_need_review=len(replies_need_review), - num_prompts_need_review=len(prompts_need_review), - num_missing_prompts=max(0, self.cfg.max_active_trees - num_active_trees), - num_missing_replies=num_missing_replies, - ) - - if task_type == TaskType.NONE: - raise OasstError( - f"No tasks of type '{desired_task_type.value}' are currently available.", - OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE, - HTTPStatus.SERVICE_UNAVAILABLE, + task_role = TaskRole.ANY + if desired_task_type == protocol_schema.TaskRequestType.random: + task_type = self._random_task_selection( + num_ranking_tasks=len(incomplete_rankings), + num_replies_need_review=len(replies_need_review), + num_prompts_need_review=len(prompts_need_review), + num_missing_prompts=max(0, self.cfg.max_active_trees - num_active_trees), + num_missing_replies=num_missing_replies, ) - if task_role != TaskRole.ANY: - # Todo: Allow role specific message selection... - raise OasstError( - f"No tasks of type '{desired_task_type.value}' are currently available.", - OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE, - HTTPStatus.SERVICE_UNAVAILABLE, + if task_type == TaskType.NONE: + raise OasstError( + f"No tasks of type '{protocol_schema.TaskRequestType.random.value}' are currently available.", + OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE, + HTTPStatus.SERVICE_UNAVAILABLE, + ) + else: + task_count_by_type = self._determine_task_availability_internal( + num_active_trees=num_active_trees, + extensible_parents=extensible_parents, + prompts_need_review=prompts_need_review, + replies_need_review=replies_need_review, + incomplete_rankings=incomplete_rankings, ) + available_count = task_count_by_type.get(desired_task_type) + if not available_count: + raise OasstError( + f"No tasks of type '{desired_task_type.value}' are currently available.", + OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE, + HTTPStatus.SERVICE_UNAVAILABLE, + ) + + task_type_role_map = { + protocol_schema.TaskRequestType.initial_prompt: (TaskType.PROMPT, TaskRole.ANY), + protocol_schema.TaskRequestType.prompter_reply: (TaskType.REPLY, TaskRole.PROMPTER), + protocol_schema.TaskRequestType.assistant_reply: (TaskType.REPLY, TaskRole.ASSISTANT), + protocol_schema.TaskRequestType.rank_prompter_replies: (TaskType.RANKING, TaskRole.PROMPTER), + protocol_schema.TaskRequestType.rank_assistant_replies: (TaskType.RANKING, TaskRole.ASSISTANT), + protocol_schema.TaskRequestType.label_initial_prompt: (TaskType.LABEL_PROMPT, TaskRole.ANY), + protocol_schema.TaskRequestType.label_assistant_reply: (TaskType.LABEL_REPLY, TaskRole.ASSISTANT), + protocol_schema.TaskRequestType.label_prompter_reply: (TaskType.LABEL_REPLY, TaskRole.PROMPTER), + } + + task_type, task_role = task_type_role_map[desired_task_type] + message_tree_id = None parent_message_id = None logger.debug(f"selected {task_type=}") match task_type: case TaskType.RANKING: - assert len(incomplete_rankings) > 0 - ranking_parent_id = random.choice(incomplete_rankings).parent_id + if task_role == TaskRole.PROMPTER: + incomplete_rankings = list(filter(lambda m: m.role == "prompter", incomplete_rankings)) + elif task_role == TaskRole.ASSISTANT: + incomplete_rankings = list(filter(lambda m: m.role == "assistant", incomplete_rankings)) - messages = self.pr.fetch_message_conversation(ranking_parent_id) - assert len(messages) > 1 and messages[-1].id == ranking_parent_id - ranking_parent = messages[-1] - assert not ranking_parent.deleted and ranking_parent.review_result - conversation = prepare_conversation(messages) - replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True) + if len(incomplete_rankings) > 0: + ranking_parent_id = random.choice(incomplete_rankings).parent_id - assert len(replies) > 1 - random.shuffle(replies) # hand out replies in random order - reply_messages = prepare_conversation_message_list(replies) - replies = [p.text for p in replies] + messages = self.pr.fetch_message_conversation(ranking_parent_id) + assert len(messages) > 1 and messages[-1].id == ranking_parent_id + ranking_parent = messages[-1] + assert not ranking_parent.deleted and ranking_parent.review_result + conversation = prepare_conversation(messages) + replies = self.pr.fetch_message_children(ranking_parent_id, reviewed=True, exclude_deleted=True) - if messages[-1].role == "assistant": - logger.info("Generating a RankPrompterRepliesTask.") - task = protocol_schema.RankPrompterRepliesTask( - conversation=conversation, - replies=replies, - reply_messages=reply_messages, - ranking_parent_id=ranking_parent.id, - message_tree_id=ranking_parent.message_tree_id, - ) - else: - logger.info("Generating a RankAssistantRepliesTask.") - task = protocol_schema.RankAssistantRepliesTask( - conversation=conversation, - replies=replies, - reply_messages=reply_messages, - ranking_parent_id=ranking_parent.id, - message_tree_id=ranking_parent.message_tree_id, - ) + assert len(replies) > 1 + random.shuffle(replies) # hand out replies in random order + reply_messages = prepare_conversation_message_list(replies) + replies = [p.text for p in replies] - parent_message_id = ranking_parent_id - message_tree_id = messages[-1].message_tree_id + if messages[-1].role == "assistant": + logger.info("Generating a RankPrompterRepliesTask.") + task = protocol_schema.RankPrompterRepliesTask( + conversation=conversation, + replies=replies, + reply_messages=reply_messages, + ranking_parent_id=ranking_parent.id, + message_tree_id=ranking_parent.message_tree_id, + ) + else: + logger.info("Generating a RankAssistantRepliesTask.") + task = protocol_schema.RankAssistantRepliesTask( + conversation=conversation, + replies=replies, + reply_messages=reply_messages, + ranking_parent_id=ranking_parent.id, + message_tree_id=ranking_parent.message_tree_id, + ) + + parent_message_id = ranking_parent_id + message_tree_id = messages[-1].message_tree_id case TaskType.LABEL_REPLY: - assert len(replies_need_review) > 0 - random_reply_message_id = random.choice(replies_need_review) - messages = self.pr.fetch_message_conversation(random_reply_message_id) + if task_role == TaskRole.PROMPTER: + replies_need_review = list(filter(lambda m: m.role == "prompter", replies_need_review)) + elif task_role == TaskRole.ASSISTANT: + replies_need_review = list(filter(lambda m: m.role == "assistant", replies_need_review)) - conversation = prepare_conversation(messages[:-1]) - message = messages[-1] + if len(replies_need_review) > 0: + random_reply_message = random.choice(replies_need_review) + messages = self.pr.fetch_message_conversation(random_reply_message) - self.cfg.p_full_labeling_review_reply_prompter: float = 0.1 + conversation = prepare_conversation(messages[:-1]) + message = messages[-1] - label_mode = protocol_schema.LabelTaskMode.full - valid_labels = self._all_text_labels + self.cfg.p_full_labeling_review_reply_prompter: float = 0.1 - if message.role == "assistant": - if random.random() > self.cfg.p_full_labeling_review_reply_assistant: - valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)) - label_mode = protocol_schema.LabelTaskMode.simple - logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})") - task = protocol_schema.LabelAssistantReplyTask( - message_id=message.id, - conversation=conversation, - reply=message.text, - valid_labels=valid_labels, - mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)), - mode=label_mode, - ) - else: - if random.random() > self.cfg.p_full_labeling_review_reply_prompter: - valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)) - label_mode = protocol_schema.LabelTaskMode.simple - logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})") - task = protocol_schema.LabelPrompterReplyTask( - message_id=message.id, - conversation=conversation, - reply=message.text, - valid_labels=valid_labels, - mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)), - mode=label_mode, - ) + label_mode = protocol_schema.LabelTaskMode.full + valid_labels = self._all_text_labels - parent_message_id = message.id - message_tree_id = message.message_tree_id + if message.role == "assistant": + if ( + desired_task_type == protocol_schema.TaskRequestType.random + and random.random() > self.cfg.p_full_labeling_review_reply_assistant + ): + valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)) + label_mode = protocol_schema.LabelTaskMode.simple + logger.info(f"Generating a LabelAssistantReplyTask. ({label_mode=:s})") + task = protocol_schema.LabelAssistantReplyTask( + message_id=message.id, + conversation=conversation, + reply=message.text, + valid_labels=valid_labels, + mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_assistant_reply)), + mode=label_mode, + ) + else: + if ( + desired_task_type == protocol_schema.TaskRequestType.random + and random.random() > self.cfg.p_full_labeling_review_reply_prompter + ): + valid_labels = list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)) + label_mode = protocol_schema.LabelTaskMode.simple + logger.info(f"Generating a LabelPrompterReplyTask. ({label_mode=:s})") + task = protocol_schema.LabelPrompterReplyTask( + message_id=message.id, + conversation=conversation, + reply=message.text, + valid_labels=valid_labels, + mandatory_labels=list(map(lambda x: x.value, self.cfg.mandatory_labels_prompter_reply)), + mode=label_mode, + ) + + parent_message_id = message.id + message_tree_id = message.message_tree_id case TaskType.REPLY: # select a tree with missing replies - extensible_parents = self.query_extendible_parents() - assert len(extensible_parents) > 0 + if task_role == TaskRole.PROMPTER: + extensible_parents = list(filter(lambda x: x.parent_role == "assistant", extensible_parents)) + elif task_role == TaskRole.ASSISTANT: + extensible_parents = list(filter(lambda x: x.parent_role == "prompter", extensible_parents)) - # fetch random conversation to extend - random_parent = random.choice(extensible_parents) - logger.debug(f"selected {random_parent=}") - messages = self.pr.fetch_message_conversation(random_parent.parent_id) - assert all(m.review_result for m in messages) # ensure all messages have positive review - conversation = prepare_conversation(messages) + if len(extensible_parents) > 0: + random_parent = random.choice(extensible_parents) - # generate reply task depending on last message - if messages[-1].role == "assistant": - logger.info("Generating a PrompterReplyTask.") - task = protocol_schema.PrompterReplyTask(conversation=conversation) - else: - logger.info("Generating a AssistantReplyTask.") - task = protocol_schema.AssistantReplyTask(conversation=conversation) + # fetch random conversation to extend + logger.debug(f"selected {random_parent=}") + messages = self.pr.fetch_message_conversation(random_parent.parent_id) + assert all(m.review_result for m in messages) # ensure all messages have positive review + conversation = prepare_conversation(messages) - parent_message_id = messages[-1].id - message_tree_id = messages[-1].message_tree_id + # generate reply task depending on last message + if messages[-1].role == "assistant": + logger.info("Generating a PrompterReplyTask.") + task = protocol_schema.PrompterReplyTask(conversation=conversation) + else: + logger.info("Generating a AssistantReplyTask.") + task = protocol_schema.AssistantReplyTask(conversation=conversation) + + parent_message_id = messages[-1].id + message_tree_id = messages[-1].message_tree_id case TaskType.LABEL_PROMPT: assert len(prompts_need_review) > 0 - message = self.pr.fetch_message(random.choice(prompts_need_review)) + message = random.choice(prompts_need_review) label_mode = protocol_schema.LabelTaskMode.full valid_labels = self._all_text_labels @@ -336,10 +410,18 @@ class TreeManager: case _: task = None + if task is None: + raise OasstError( + f"No task of type '{desired_task_type.value}' is currently available.", + OasstErrorCode.TASK_REQUESTED_TYPE_NOT_AVAILABLE, + HTTPStatus.SERVICE_UNAVAILABLE, + ) + logger.info(f"Generated {task=}.") return task, message_tree_id, parent_message_id + @async_managed_tx_method(CommitMode.COMMIT) async def handle_interaction(self, interaction: protocol_schema.AnyInteraction) -> protocol_schema.Task: pr = self.pr match type(interaction): @@ -358,7 +440,6 @@ class TreeManager: if not message.parent_id: logger.info(f"TreeManager: Inserting new tree state for initial prompt {message.id=}") self._insert_default_state(message.id) - self.db.commit() if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION: try: @@ -428,7 +509,6 @@ class TreeManager: if acceptance_score > self.cfg.acceptance_threshold_initial_prompt: msg.review_result = True self.db.add(msg) - self.db.commit() logger.info( f"Initial prompt message was accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}" ) @@ -439,7 +519,6 @@ class TreeManager: if not msg.review_result and acceptance_score > self.cfg.acceptance_threshold_reply: msg.review_result = True self.db.add(msg) - self.db.commit() logger.info( f"Reply message message accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}" ) @@ -451,6 +530,7 @@ class TreeManager: return protocol_schema.TaskDone() + @managed_tx_method(CommitMode.FLUSH) def _enter_state(self, mts: MessageTreeState, state: message_tree_state.State): assert mts and mts.active @@ -460,7 +540,6 @@ class TreeManager: mts.active = False mts.state = state.value self.db.add(mts) - self.db.commit() if is_terminal: logger.info(f"Tree entered terminal '{mts.state}' state ({mts.message_tree_id=})") @@ -472,6 +551,7 @@ class TreeManager: mts = self.pr.fetch_tree_state(message_tree_id) self._enter_state(mts, message_tree_state.State.ABORTED_LOW_GRADE) + @managed_tx_method(CommitMode.COMMIT) def check_condition_for_growing_state(self, message_tree_id: UUID) -> bool: logger.debug(f"check_condition_for_growing_state({message_tree_id=})") @@ -489,6 +569,7 @@ class TreeManager: self._enter_state(mts, message_tree_state.State.GROWING) return True + @managed_tx_method(CommitMode.COMMIT) def check_condition_for_ranking_state(self, message_tree_id: UUID) -> bool: logger.debug(f"check_condition_for_ranking_state({message_tree_id=})") @@ -514,7 +595,8 @@ class TreeManager: logger.debug(f"False {mts.active=}, {mts.state=}") return False - rankings_by_message = self.query_tree_ranking_results(message_tree_id) + ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant" + rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter) for parent_msg_id, ranking in rankings_by_message.items(): if len(ranking) < self.cfg.num_required_rankings: logger.debug(f"False {parent_msg_id=} {len(ranking)=}") @@ -527,68 +609,59 @@ class TreeManager: # calculate acceptance based on spam label return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels]) - _sql_find_prompts_need_review = """ --- find initial prompts that need more reviews -SELECT m.id -FROM message_tree_state mts - LEFT JOIN message m ON mts.message_tree_id = m.id -WHERE mts.active - AND mts.state = :state - AND NOT m.review_result - AND NOT m.deleted - AND m.review_count < :num_reviews_initial_prompt - AND m.parent_id is NULL - AND (:excluded_user_id IS NULL OR m.user_id != :excluded_user_id) -""" - - def query_prompts_need_review(self) -> list[UUID]: + def query_prompts_need_review(self) -> list[Message]: """ - Select id of initial prompts with less then required rankings in active message tree + Select initial prompt messages with less then required rankings in active message tree (active == True in message_tree_state) """ - r = self.db.execute( - text(self._sql_find_prompts_need_review), - { - "state": message_tree_state.State.INITIAL_PROMPT_REVIEW, - "num_reviews_initial_prompt": self.cfg.num_reviews_initial_prompt, - "excluded_user_id": None if settings.DEBUG_ALLOW_SELF_LABELING else self.pr.user_id, - }, + qry = ( + self.db.query(Message) + .select_from(MessageTreeState) + .outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id) + .filter( + MessageTreeState.active, + MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW, + not_(Message.review_result), + not_(Message.deleted), + Message.review_count < self.cfg.num_reviews_initial_prompt, + Message.parent_id.is_(None), + ) ) - return [x["id"] for x in r.all()] - _sql_find_replies_need_review = """ -SELECT m.id -FROM message_tree_state mts - LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id -WHERE mts.active - AND mts.state = :breeding_state - AND NOT m.review_result - AND NOT m.deleted - AND m.review_count < :num_required_reviews - AND m.parent_id is NOT NULL - AND (:excluded_user_id IS NULL OR m.user_id != :excluded_user_id) -""" + if not settings.DEBUG_ALLOW_SELF_LABELING: + qry = qry.filter(Message.user_id != self.pr.user_id) - def query_replies_need_review(self) -> list[UUID]: + return qry.all() + + def query_replies_need_review(self) -> list[Message]: """ - Select ids of child messages (parent_id IS NOT NULL) with less then required rankings + Select child messages (parent_id IS NOT NULL) with less then required rankings in active message tree (active == True in message_tree_state) """ - r = self.db.execute( - text(self._sql_find_replies_need_review), - { - "breeding_state": message_tree_state.State.GROWING, - "num_required_reviews": self.cfg.num_reviews_reply, - "excluded_user_id": None if settings.DEBUG_ALLOW_SELF_LABELING else self.pr.user_id, - }, + qry = ( + self.db.query(Message) + .select_from(MessageTreeState) + .outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id) + .filter( + MessageTreeState.active, + MessageTreeState.state == message_tree_state.State.GROWING, + not_(Message.review_result), + not_(Message.deleted), + Message.review_count < self.cfg.num_reviews_reply, + Message.parent_id.is_not(None), + ) ) - return [x["id"] for x in r.all()] + + if not settings.DEBUG_ALLOW_SELF_LABELING: + qry = qry.filter(Message.user_id != self.pr.user_id) + + return qry.all() _sql_find_incomplete_rankings = """ -- find incomplete rankings -SELECT m.parent_id, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count, +SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count, COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings FROM message_tree_state mts LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id @@ -597,7 +670,7 @@ WHERE mts.active -- only consider active trees AND m.review_result -- must be reviewed AND NOT m.deleted -- not deleted AND m.parent_id IS NOT NULL -- ignore initial prompts -GROUP BY m.parent_id +GROUP BY m.parent_id, m.role HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings """ @@ -615,10 +688,10 @@ HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings _sql_find_extendible_parents = """ -- find all extendible parent nodes -SELECT m.id as parent_id, m.depth, m.message_tree_id, COUNT(c.id) active_children_count +SELECT m.id as parent_id, m.role as parent_role, m.depth, m.message_tree_id, COUNT(c.id) active_children_count FROM message_tree_state mts LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id -- all elements of message tree - LEFT JOIN message c ON m.id = c.Id -- child nodes + LEFT JOIN message c ON m.id = c.parent_id -- child nodes WHERE mts.active -- only consider active trees AND mts.state = :growing_state -- message tree must be growing AND NOT m.deleted -- ignore deleted messages as parents @@ -626,7 +699,7 @@ WHERE mts.active -- only consider active trees AND m.review_result -- parent node must have positive review AND NOT c.deleted -- don't count deleted children AND (c.review_result OR c.review_count < :num_reviews_reply) -- don't count children with negative review but count elements under review -GROUP BY m.id, m.depth, m.message_tree_id, mts.max_children_count +GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children """ @@ -635,10 +708,7 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children r = self.db.execute( text(self._sql_find_extendible_parents), - { - "growing_state": message_tree_state.State.GROWING, - "num_reviews_reply": self.cfg.num_reviews_reply, - }, + {"growing_state": message_tree_state.State.GROWING, "num_reviews_reply": self.cfg.num_reviews_reply}, ) return [ExtendibleParentRow.from_orm(x) for x in r.all()] @@ -670,21 +740,27 @@ HAVING COUNT(m.id) < mts.goal_tree_size ) return [ActiveTreeSizeRow.from_orm(x) for x in r.all()] - _sql_get_tree_size = """ -SELECT mts.message_tree_id, mts.goal_tree_size, COUNT(m.id) AS tree_size -FROM message_tree_state mts - LEFT JOIN message m ON mts.message_tree_id = m.message_tree_id -WHERE mts.active - AND NOT m.deleted - AND m.review_result - AND mts.message_tree_id = :message_tree_id -GROUP BY mts.message_tree_id, mts.goal_tree_size -""" - def query_tree_size(self, message_tree_id: UUID) -> ActiveTreeSizeRow: """Returns the number of reviewed not deleted messages in the message tree.""" - r = self.db.execute(text(self._sql_get_tree_size), {"message_tree_id": message_tree_id}) - return ActiveTreeSizeRow.from_orm(r.one()) + + qry = ( + self.db.query( + MessageTreeState.message_tree_id.label("message_tree_id"), + MessageTreeState.goal_tree_size.label("goal_tree_size"), + func.count(Message.id).label("tree_size"), + ) + .select_from(MessageTreeState) + .outerjoin(Message, MessageTreeState.message_tree_id == Message.message_tree_id) + .filter( + MessageTreeState.active, + not_(Message.deleted), + Message.review_result, + MessageTreeState.message_tree_id == message_tree_id, + ) + .group_by(MessageTreeState.message_tree_id, MessageTreeState.goal_tree_size) + ) + + return ActiveTreeSizeRow.from_orm(qry.one()) def query_misssing_tree_states(self) -> list[UUID]: """Find all initial prompt messages that have no associated message tree state""" @@ -701,7 +777,7 @@ GROUP BY mts.message_tree_id, mts.goal_tree_size return [m.id for m in qry_missing_tree_states.all()] _sql_find_tree_ranking_results = """ --- get all ranking results of completed tasks for all parents with >=2 children +-- get all ranking results of completed tasks for all parents with >= 2 children SELECT p.parent_id, mr.* FROM ( -- find parents with > 1 children @@ -711,7 +787,8 @@ SELECT p.parent_id, mr.* FROM WHERE m.review_result -- must be reviewed AND NOT m.deleted -- not deleted AND m.parent_id IS NOT NULL -- ignore initial prompts - AND mts.message_tree_id = :message_tree_id + AND (:role IS NULL OR m.role = :role) -- children with matching role + AND mts.message_tree_id = :message_tree_id GROUP BY m.parent_id, m.message_tree_id HAVING COUNT(m.id) > 1 ) as p @@ -719,11 +796,21 @@ LEFT JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload' """ - def query_tree_ranking_results(self, message_tree_id: UUID) -> dict[UUID, list[MessageReaction]]: + def query_tree_ranking_results( + self, + message_tree_id: UUID, + role_filter: str = "assistant", + ) -> dict[UUID, list[MessageReaction]]: """Finds all completed ranking restuls for a message_tree""" + + assert role_filter in (None, "assistant", "prompter") + r = self.db.execute( text(self._sql_find_tree_ranking_results), - {"message_tree_id": message_tree_id}, + { + "message_tree_id": message_tree_id, + "role": role_filter, + }, ) rankings_by_message = {} @@ -735,6 +822,7 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin rankings_by_message[parent_id].append(MessageReaction.from_orm(x)) return rankings_by_message + @managed_tx_method(CommitMode.COMMIT) def ensure_tree_states(self): """Add message tree state rows for all root nodes (inital prompt messages).""" @@ -746,23 +834,21 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin state = message_tree_state.State.GROWING logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=})") self._insert_default_state(id, state=state) - self.db.commit() def query_num_active_trees(self) -> int: query = self.db.query(func.count(MessageTreeState.message_tree_id)).filter(MessageTreeState.active) return query.scalar() def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]: - sql_qry = """ -SELECT tl.* -FROM task t - INNER JOIN text_labels tl ON tl.id = t.id -WHERE t.done = TRUE - AND tl.message_id = :message_id -""" - r = self.db.execute(text(sql_qry), {"message_id": message_id}) - return [TextLabels.from_orm(x) for x in r.all()] + qry = ( + self.db.query(TextLabels) + .select_from(Task) + .join(TextLabels, Task.id == TextLabels.id) + .filter(Task.done, TextLabels.message_id == message_id) + ) + return qry.all() + @managed_tx_method(CommitMode.FLUSH) def _insert_tree_state( self, root_message_id: UUID, @@ -784,6 +870,7 @@ WHERE t.done = TRUE self.db.add(model) return model + @managed_tx_method(CommitMode.FLUSH) def _insert_default_state( self, root_message_id: UUID, @@ -800,12 +887,12 @@ WHERE t.done = TRUE if __name__ == "__main__": - from oasst_backend.api.deps import get_dummy_api_client + from oasst_backend.api.deps import api_auth from oasst_backend.database import engine from oasst_backend.prompt_repository import PromptRepository with Session(engine) as db: - api_client = get_dummy_api_client(db) + api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db) dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local") pr = PromptRepository(db=db, api_client=api_client, client_user=dummy_user) @@ -814,15 +901,21 @@ if __name__ == "__main__": tm = TreeManager(db, pr, cfg) tm.ensure_tree_states() - print("query_num_active_trees", tm.query_num_active_trees()) - print("query_incomplete_rankings", tm.query_incomplete_rankings()) - print("query_incomplete_reply_reviews", tm.query_replies_need_review()) - print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review()) - print("query_extendible_trees", tm.query_extendible_trees()) - print("query_extendible_parents", tm.query_extendible_parents()) - - print("next_task:", tm.next_task()) + # print("query_num_active_trees", tm.query_num_active_trees()) + # print("query_incomplete_rankings", tm.query_incomplete_rankings()) + # print("query_replies_need_review", tm.query_replies_need_review()) + # print("query_incomplete_initial_prompt_reviews", tm.query_prompts_need_review()) + # print("query_extendible_trees", tm.query_extendible_trees()) + # print("query_extendible_parents", tm.query_extendible_parents()) + # print("query_tree_size", tm.query_tree_size(message_tree_id=UUID("bdf434cf-4df5-4b74-949c-a5a157bc3292"))) print( - ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921")) + "query_reviews_for_message", + tm.query_reviews_for_message(message_id=UUID("6a444493-0d48-4316-a9f1-7e263f5a2473")), ) + + # print("next_task:", tm.next_task()) + + # print( + # "query_tree_ranking_results", tm.query_tree_ranking_results(UUID("6036f58f-41b5-48c4-bdd9-b16f34ab1312")) + # ) diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 8d8a96d5..578dc5f1 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -2,6 +2,7 @@ from typing import Optional from uuid import UUID from oasst_backend.models import ApiClient, User +from oasst_backend.utils.database_utils import CommitMode, managed_tx_method from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session @@ -62,6 +63,7 @@ class UserRepository: return user + @managed_tx_method(CommitMode.COMMIT) def update_user(self, id: UUID, enabled: Optional[bool] = None, notes: Optional[str] = None) -> None: """ Update a user by global user ID to disable or set admin notes. Only trusted clients may update users. @@ -83,8 +85,8 @@ class UserRepository: user.notes = notes self.db.add(user) - self.db.commit() + @managed_tx_method(CommitMode.COMMIT) def mark_user_deleted(self, id: UUID) -> None: """ Update a user by global user ID to set deleted flag. Only trusted clients may delete users. @@ -103,8 +105,8 @@ class UserRepository: user.deleted = True self.db.add(user) - self.db.commit() + @managed_tx_method(CommitMode.COMMIT) def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]: if not client_user: return None @@ -127,13 +129,10 @@ class UserRepository: auth_method=client_user.auth_method, ) self.db.add(user) - self.db.commit() - self.db.refresh(user) elif client_user.display_name and client_user.display_name != user.display_name: # we found the user but the display name changed user.display_name = client_user.display_name self.db.add(user) - self.db.commit() return user def query_users( diff --git a/backend/oasst_backend/utils/database_utils.py b/backend/oasst_backend/utils/database_utils.py new file mode 100644 index 00000000..d378c6a6 --- /dev/null +++ b/backend/oasst_backend/utils/database_utils.py @@ -0,0 +1,138 @@ +from enum import IntEnum +from functools import wraps +from http import HTTPStatus +from typing import Callable + +from loguru import logger +from oasst_backend.config import settings +from oasst_backend.database import engine +from oasst_shared.exceptions import OasstError, OasstErrorCode +from sqlalchemy.exc import OperationalError +from sqlmodel import Session, SQLModel + + +class CommitMode(IntEnum): + """ + Commit modes for the managed tx methods + """ + + NONE = 0 + FLUSH = 1 + COMMIT = 2 + + +""" +* managed_tx_method and async_managed_tx_method methods are decorators functions +* to be used on class functions. It expects the Class to have a 'db' Session object +* initialised +* TODO: tx method decorator for non class methods +""" + + +def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT): + def decorator(f): + @wraps(f) + def wrapped_f(self, *args, **kwargs): + try: + for i in range(num_retries): + try: + result = f(self, *args, **kwargs) + if auto_commit == CommitMode.COMMIT: + self.db.commit() + elif auto_commit == CommitMode.FLUSH: + self.db.flush() + if isinstance(result, SQLModel): + self.db.refresh(result) + return result + except OperationalError: + logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.") + self.db.rollback() + raise OasstError( + "DATABASE_MAX_RETIRES_EXHAUSTED", + error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED, + http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, + ) + except Exception as e: + logger.error("DB Rollback Failure") + raise e + + return wrapped_f + + return decorator + + +def async_managed_tx_method( + auto_commit: CommitMode = CommitMode.COMMIT, num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT +): + def decorator(f): + @wraps(f) + async def wrapped_f(self, *args, **kwargs): + try: + for i in range(num_retries): + try: + result = await f(self, *args, **kwargs) + if auto_commit == CommitMode.COMMIT: + self.db.commit() + elif auto_commit == CommitMode.FLUSH: + self.db.flush() + if isinstance(result, SQLModel): + self.db.refresh(result) + return result + except OperationalError: + logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.") + self.db.rollback() + raise OasstError( + "DATABASE_MAX_RETIRES_EXHAUSTED", + error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED, + http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, + ) + except Exception as e: + logger.exception("DB Rollback Failure") + raise e + + return wrapped_f + + return decorator + + +def default_session_factor() -> Session: + return Session(engine) + + +def managed_tx_function( + auto_commit: CommitMode = CommitMode.COMMIT, + num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT, + session_factory: Callable[..., Session] = default_session_factor, +): + """Passes Session object as first argument to wrapped function.""" + + def decorator(f): + @wraps(f) + def wrapped_f(*args, **kwargs): + try: + for i in range(num_retries): + with session_factory() as session: + try: + result = f(session, *args, **kwargs) + if auto_commit == CommitMode.COMMIT: + session.commit() + elif auto_commit == CommitMode.FLUSH: + session.flush() + if isinstance(result, SQLModel): + session.refresh(result) + return result + except OperationalError: + logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.") + session.rollback() + raise OasstError( + "DATABASE_MAX_RETIRES_EXHAUSTED", + error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED, + http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, + ) + except Exception as e: + logger.error("DB Rollback Failure") + raise e + + return wrapped_f + + return decorator diff --git a/deploy/prod-node/nginx/nginx.conf b/deploy/prod-node/nginx/nginx.conf index a8b34cb7..b14290fc 100644 --- a/deploy/prod-node/nginx/nginx.conf +++ b/deploy/prod-node/nginx/nginx.conf @@ -16,6 +16,19 @@ http { } } + server { + listen 443 ssl http2; + + server_name www.open-assistant.io; + + ssl_certificate /etc/nginx/ssl/live/www.open-assistant.io/fullchain.pem; + ssl_certificate_key /etc/nginx/ssl/live/www.open-assistant.io/privkey.pem; + + location / { + return 301 https://open-assistant.io$request_uri; + } + } + server { listen 443 ssl http2; @@ -25,7 +38,9 @@ http { ssl_certificate_key /etc/nginx/ssl/live/open-assistant.io/privkey.pem; location / { - return 301 https://web.prod.open-assistant.io$request_uri; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_pass http://127.0.0.1:3200; } } diff --git a/docs/docs/research/general.md b/docs/docs/research/general.md index 4186ebac..62ace821 100644 --- a/docs/docs/research/general.md +++ b/docs/docs/research/general.md @@ -8,15 +8,29 @@ This page lists research papers that are relevant to the project. - Generating Text From Language Models - Automatically Generating Instruction Data for Training - Uncertainty Estimation of Language Model Outputs +- Evidence-Guided Text Generation +- Reward Model Optimization +- Dialogue-Oriented RLHF +- Reduce Harms in Language Models -## Reinforcement Learning from Human Feedback +## Reinforcement Learning from Human Feedback Reinforcement Learning from Human Feedback (RLHF) is a method for fine-tuning a generative language models based on a reward model that is learned from human preference data. This method facilitates the learning of instruction-tuned models, among other things. -### Learning to summarize from human feedback [[ArXiv](https://arxiv.org/pdf/2009.01325.pdf)], [[Github](https://github.com/openai/summarize-from-feedback)] +### Fine-Tuning Language Models from Human Preferences [[ArXiv](https://arxiv.org/abs/1909.08593)], [[GitHub](https://github.com/openai/lm-human-preferences)] + +> In this paper, we build on advances in generative pretraining of language +> models to apply reward learning to four natural language tasks: continuing +> text with positive sentiment or physically descriptive language, and +> summarization tasks on the TL;DR and CNN/Daily Mail datasets. For stylistic +> continuation we achieve good results with only 5,000 comparisons evaluated by +> humans. For summarization, models trained with 60,000 comparisons copy whole +> sentences from the input but skip irrelevant preamble. + +### Learning to summarize from human feedback [[ArXiv](https://arxiv.org/abs/2009.01325)], [[GitHub](https://github.com/openai/summarize-from-feedback)] > In this work, we show that it is possible to significantly improve summary > quality by training a model to optimize for human preferences. We collect a @@ -24,7 +38,18 @@ models, among other things. > model to predict the human-preferred summary, and use that model as a reward > function to fine-tune a summarization policy using reinforcement learning. -### Training language models to follow instructions with human feedback [[ArXiv](https://arxiv.org/pdf/2203.02155.pdf)] +### Recursively Summarizing Books with Human Feedback [[ArXiv](https://arxiv.org/abs/2109.10862)] + +> Our method combines learning from human feedback with recursive task +> decomposition: we use models trained on smaller parts of the task to assist +> humans in giving feedback on the broader task. We collect a large volume of +> demonstrations and comparisons from human labelers. Our resulting model +> generates sensible summaries of entire books, even matching the quality of +> human-written summaries in a few cases (∼5% of books). We achieve +> state-of-the-art results on the recent BookSum dataset for book-length +> summarization. We release datasets of samples from our model. + +### Training language models to follow instructions with human feedback [[ArXiv](https://arxiv.org/abs/2203.02155)] > Starting with a set of labeler-written prompts and prompts submitted through > the OpenAI API, we collect a dataset of labeler demonstrations of the desired @@ -33,7 +58,7 @@ models, among other things. > fine-tune this supervised model using reinforcement learning from human > feedback. -### Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback [[ArXiv](https://arxiv.org/pdf/2204.05862.pdf)] +### Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback [[ArXiv](https://arxiv.org/abs/2204.05862)] > We apply preference modeling and reinforcement learning from human feedback > (RLHF) to finetune language models to act as helpful and harmless assistants. @@ -41,6 +66,31 @@ models, among other things. > evaluations, and is fully compatible with training for specialized skills such > as python coding and summarization. +### Self-critiquing models for assisting human evaluators [[ArXiv](https://arxiv.org/abs/2206.05802)] + +> We fine-tune large language models to write natural language critiques +> (natural language critical comments) using behavioral cloning. On a +> topic-based summarization task, critiques written by our models help humans +> find flaws in summaries that they would have otherwise missed. We study +> scaling properties of critiquing with both topic-based summarization and +> synthetic tasks. Finally, we motivate and introduce a framework for comparing +> critiquing ability to generation and discrimination ability. These results are +> a proof of concept for using AI-assisted human feedback to scale the +> supervision of machine learning systems to tasks that are difficult for humans +> to evaluate directly. We release our training datasets. + +### Is Reinforcement Learning (Not) for Natural Language Processing?: Benchmarks, Baselines, and Building Blocks for Natural Language Policy Optimization [[ArXiv](https://arxiv.org/abs/2210.01241)] + +> We tackle the problem of aligning pre-trained large language models (LMs) with +> human preferences. We present the GRUE (General Reinforced-language +> Understanding Evaluation) benchmark, a set of 6 language generation tasks +> which are supervised by reward functions which capture automated measures of +> human preference. Finally, we introduce an easy-to-use, performant RL +> algorithm, NLPO (Natural Language Policy Optimization) that learns to +> effectively reduce the combinatorial action space in language generation. We +> show that RL techniques are generally better than supervised methods at +> aligning LMs to human preferences. + ## Generating Text From Language Models A language model generates output text token by token, autoregressively. The @@ -48,7 +98,7 @@ large search space of this task requires some method of narrowing down the set of tokens to be considered in each step. This method, in turn, has a big impact on the quality of the resulting text. -### RANKGEN: Improving Text Generation with Large Ranking Models [[ArXiv](https://arxiv.org/pdf/2205.09726.pdf)], [[Github](https://github.com/martiansideofthemoon/rankgen)] +### RANKGEN: Improving Text Generation with Large Ranking Models [[ArXiv](https://arxiv.org/abs/2205.09726)], [[GitHub](https://github.com/martiansideofthemoon/rankgen)] > Given an input sequence (or prefix), modern language models often assign high > probabilities to output sequences that are repetitive, incoherent, or @@ -65,7 +115,7 @@ annotated data for the purpose of training [instruction-aligned](https://openai.com/blog/instruction-following/) language models. -### SELF-INSTRUCT: Aligning Language Model with Self Generated Instructions [[ArXiv](https://arxiv.org/pdf/2212.10560.pdf)], [[Github](https://github.com/yizhongw/self-instruct)]. +### SELF-INSTRUCT: Aligning Language Model with Self Generated Instructions [[ArXiv](https://arxiv.org/abs/2212.10560)], [[GitHub](https://github.com/yizhongw/self-instruct)]. > We introduce SELF-INSTRUCT, a framework for improving the > instruction-following capabilities of pretrained language models by @@ -76,7 +126,7 @@ models. > SuperNaturalInstructions, on par with the performance of InstructGPT-0011, > which is trained with private user data and human annotations. -### Tuning Language Models with (Almost) No Human Labor. [[ArXiv](https://arxiv.org/pdf/2212.09689.pdf)], [[Github](https://github.com/orhonovich/unnatural-instructions)]. +### Tuning Language Models with (Almost) No Human Labor. [[ArXiv](https://arxiv.org/abs/2212.09689)], [[GitHub](https://github.com/orhonovich/unnatural-instructions)]. > In this work, we introduce Unnatural Instructions: a large dataset of creative > and diverse instructions, collected with virtually no human labor. We collect @@ -91,7 +141,7 @@ models. ## Uncertainty Estimation of Language Model Outputs -### Teaching models to express their uncertainty in words [[Arxiv](https://arxiv.org/pdf/2205.14334.pdf)] +### Teaching models to express their uncertainty in words [[ArXiv](https://arxiv.org/abs/2205.14334)] > We show that a GPT-3 model can learn to express uncertainty about its own > answers in natural language -- without use of model logits. When given a @@ -100,3 +150,69 @@ models. > are well calibrated. The model also remains moderately calibrated under > distribution shift, and is sensitive to uncertainty in its own answers, rather > than imitating human examples. + +## Evidence-Guided Text Generation + +### WebGPT: Browser-assisted question-answering with human feedback [[ArXiv](https://arxiv.org/abs/2112.09332)] + +> We fine-tune GPT-3 to answer long-form questions using a text-based +> web-browsing environment, which allows the model to search and navigate the +> web. We are able to train models on the task using imitation learning, and +> then optimize answer quality with human feedback. Models must collect +> references while browsing in support of their answers. Our best model is +> obtained by fine-tuning GPT-3 using behavior cloning, and then performing +> rejection sampling against a reward model. + +### Teaching language models to support answers with verified quotes [[ArXiv](https://arxiv.org/abs/2203.11147)] + +> In this work we use RLHF to train "open-book" QA models that generate answers +> whilst also citing specific evidence for their claims, which aids in the +> appraisal of correctness. Supporting evidence is drawn from multiple documents +> found via a search engine, or from a single user-provided document. However, +> analysis on the adversarial TruthfulQA dataset shows why citation is only one +> part of an overall strategy for safety and trustworthiness: not all claims +> supported by evidence are true. + +## Reward Model Optimization + +### Scaling Laws for Reward Model Overoptimization [[ArXiv](https://arxiv.org/abs/2210.10760)], [[Preceding Blogpost](https://openai.com/blog/measuring-goodharts-law/)] + +> In this work, we use a synthetic setup in which a fixed "gold-standard" reward +> model plays the role of humans, providing labels used to train a proxy reward +> model. We study how the gold reward model score changes as we optimize against +> the proxy reward model using either reinforcement learning or best-of-n +> sampling. We study the effect on this relationship of the size of the reward +> model dataset. We explore the implications of these empirical results for +> theoretical considerations in AI alignment. + +## Dialogue-Oriented RLHF + +### Dynamic Planning in Open-Ended Dialogue using Reinforcement Learning [[ArXiv](https://arxiv.org/abs/2208.02294)] + +> Building automated agents that can carry on rich open-ended conversations with +> humans "in the wild" remains a formidable challenge. In this work we develop a +> real-time, open-ended dialogue system that uses reinforcement learning (RL) to +> power a bot's conversational skill at scale. Trained using crowd-sourced data, +> our novel system is able to substantially exceeds several metrics of interest +> in a live experiment with real users of the Google Assistant. + +### Improving alignment of dialogue agents via targeted human judgements [[ArXiv](https://arxiv.org/abs/2209.14375)] + +> We present Sparrow, an information-seeking dialogue agent trained to be more +> helpful, correct, and harmless compared to prompted language model baselines +> First, to make our agent more helpful and harmless, we break down the +> requirements for good dialogue into natural language rules the agent should +> followy. Second, our agent provides evidence from sources supporting factual +> claims when collecting preference judgements over model statements.Finally, we +> conduct extensive analyses showing that though our model learns to follow our +> rules it can exhibit distributional biases. + +## Reduce Harms in Language Models + +### Red Teaming Language Models to Reduce Harms: Methods, Scaling Behaviors, and Lessons Learned [[ArXiv](https://arxiv.org/abs/2209.07858)] + +> We investigate scaling behaviors for red teaming. We find that the RLHF models +> are increasingly difficult to red team as they scale, and we find a flat trend +> with scale for the other model types. We exhaustively describe our +> instructions, processes, statistical methodologies, and uncertainty about red +> teaming. diff --git a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py index 31ba00f6..e60ad746 100644 --- a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py +++ b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py @@ -18,6 +18,7 @@ class OasstErrorCode(IntEnum): DATABASE_URI_NOT_SET = 1 API_CLIENT_NOT_AUTHORIZED = 2 ROOT_TOKEN_NOT_AUTHORIZED = 3 + DATABASE_MAX_RETRIES_EXHAUSTED = 4 TOO_MANY_REQUESTS = 429 SERVER_ERROR0 = 500 @@ -31,6 +32,7 @@ class OasstErrorCode(IntEnum): TASK_INTERACTION_REQUEST_FAILED = 1004 TASK_GENERATION_FAILED = 1005 TASK_REQUESTED_TYPE_NOT_AVAILABLE = 1006 + TASK_AVAILABILITY_QUERY_FAILED = 1007 # 2000-3000: prompt_repository INVALID_FRONTEND_MESSAGE_ID = 2000 diff --git a/website/package-lock.json b/website/package-lock.json index 3f3d1870..1fa3d14d 100644 --- a/website/package-lock.json +++ b/website/package-lock.json @@ -29,7 +29,6 @@ "eslint-config-next": "13.0.6", "eslint-plugin-simple-import-sort": "^8.0.0", "focus-visible": "^5.2.0", - "formik": "^2.2.9", "framer-motion": "^6.5.1", "install": "^0.13.0", "next": "13.0.6", @@ -40,6 +39,7 @@ "react": "18.2.0", "react-dom": "18.2.0", "react-feature-flags": "^1.0.0", + "react-hook-form": "^7.42.1", "react-icons": "^4.7.1", "react-table": "^7.8.0", "sharp": "^0.31.3", @@ -20304,47 +20304,6 @@ "node": ">= 6" } }, - "node_modules/formik": { - "version": "2.2.9", - "resolved": "https://registry.npmjs.org/formik/-/formik-2.2.9.tgz", - "integrity": "sha512-LQLcISMmf1r5at4/gyJigGn0gOwFbeEAlji+N9InZF6LIMXnFNkO42sCI8Jt84YZggpD4cPWObAZaxpEFtSzNA==", - "funding": [ - { - "type": "individual", - "url": "https://opencollective.com/formik" - } - ], - "dependencies": { - "deepmerge": "^2.1.1", - "hoist-non-react-statics": "^3.3.0", - "lodash": "^4.17.21", - "lodash-es": "^4.17.21", - "react-fast-compare": "^2.0.1", - "tiny-warning": "^1.0.2", - "tslib": "^1.10.0" - }, - "peerDependencies": { - "react": ">=16.8.0" - } - }, - "node_modules/formik/node_modules/deepmerge": { - "version": "2.2.1", - "resolved": "https://registry.npmjs.org/deepmerge/-/deepmerge-2.2.1.tgz", - "integrity": "sha512-R9hc1Xa/NOBi9WRVUWg19rl1UB7Tt4kuPd+thNJgFZoxXsTz7ncaPaeIm+40oSGuP33DfMb4sZt1QIGiJzC4EA==", - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/formik/node_modules/react-fast-compare": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/react-fast-compare/-/react-fast-compare-2.0.4.tgz", - "integrity": "sha512-suNP+J1VU1MWFKcyt7RtjiSWUjvidmQSlqu+eHslq+342xCbGTYmC0mEhPCOHxlW0CywylOC1u2DFAT+bv4dBw==" - }, - "node_modules/formik/node_modules/tslib": { - "version": "1.14.1", - "resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz", - "integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg==" - }, "node_modules/forwarded": { "version": "0.2.0", "resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz", @@ -26443,12 +26402,8 @@ "node_modules/lodash": { "version": "4.17.21", "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", - "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==" - }, - "node_modules/lodash-es": { - "version": "4.17.21", - "resolved": "https://registry.npmjs.org/lodash-es/-/lodash-es-4.17.21.tgz", - "integrity": "sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw==" + "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==", + "dev": true }, "node_modules/lodash.debounce": { "version": "4.0.8", @@ -32527,6 +32482,21 @@ "node": ">=10" } }, + "node_modules/react-hook-form": { + "version": "7.42.1", + "resolved": "https://registry.npmjs.org/react-hook-form/-/react-hook-form-7.42.1.tgz", + "integrity": "sha512-2UIGqwMZksd5HS55crTT1ATLTr0rAI4jS7yVuqTaoRVDhY2Qc4IyjskCmpnmdYqUNOYFy04vW253tb2JRVh+IQ==", + "engines": { + "node": ">=12.22.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/react-hook-form" + }, + "peerDependencies": { + "react": "^16.8.0 || ^17 || ^18" + } + }, "node_modules/react-icons": { "version": "4.7.1", "resolved": "https://registry.npmjs.org/react-icons/-/react-icons-4.7.1.tgz", @@ -35486,11 +35456,6 @@ "resolved": "https://registry.npmjs.org/tiny-invariant/-/tiny-invariant-1.3.1.tgz", "integrity": "sha512-AD5ih2NlSssTCwsMznbvwMZpJ1cbhkGd2uueNxzv2jDlEeZdU04JQfRnggJQ8DrcVBGjAsCKwFBbDlVNtEMlzw==" }, - "node_modules/tiny-warning": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/tiny-warning/-/tiny-warning-1.0.3.tgz", - "integrity": "sha512-lBN9zLN/oAf68o3zNXYrdCt1kP8WsiGW8Oo2ka41b2IM5JL/S1CTyX1rW0mb/zSuJun0ZUrDxx4sqvYS2FWzPA==" - }, "node_modules/tmp": { "version": "0.2.1", "resolved": "https://registry.npmjs.org/tmp/-/tmp-0.2.1.tgz", @@ -52929,37 +52894,6 @@ "mime-types": "^2.1.12" } }, - "formik": { - "version": "2.2.9", - "resolved": "https://registry.npmjs.org/formik/-/formik-2.2.9.tgz", - "integrity": "sha512-LQLcISMmf1r5at4/gyJigGn0gOwFbeEAlji+N9InZF6LIMXnFNkO42sCI8Jt84YZggpD4cPWObAZaxpEFtSzNA==", - "requires": { - "deepmerge": "^2.1.1", - "hoist-non-react-statics": "^3.3.0", - "lodash": "^4.17.21", - "lodash-es": "^4.17.21", - "react-fast-compare": "^2.0.1", - "tiny-warning": "^1.0.2", - "tslib": "^1.10.0" - }, - "dependencies": { - "deepmerge": { - "version": "2.2.1", - "resolved": "https://registry.npmjs.org/deepmerge/-/deepmerge-2.2.1.tgz", - "integrity": "sha512-R9hc1Xa/NOBi9WRVUWg19rl1UB7Tt4kuPd+thNJgFZoxXsTz7ncaPaeIm+40oSGuP33DfMb4sZt1QIGiJzC4EA==" - }, - "react-fast-compare": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/react-fast-compare/-/react-fast-compare-2.0.4.tgz", - "integrity": "sha512-suNP+J1VU1MWFKcyt7RtjiSWUjvidmQSlqu+eHslq+342xCbGTYmC0mEhPCOHxlW0CywylOC1u2DFAT+bv4dBw==" - }, - "tslib": { - "version": "1.14.1", - "resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz", - "integrity": "sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg==" - } - } - }, "forwarded": { "version": "0.2.0", "resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz", @@ -57574,12 +57508,8 @@ "lodash": { "version": "4.17.21", "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", - "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==" - }, - "lodash-es": { - "version": "4.17.21", - "resolved": "https://registry.npmjs.org/lodash-es/-/lodash-es-4.17.21.tgz", - "integrity": "sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw==" + "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==", + "dev": true }, "lodash.debounce": { "version": "4.0.8", @@ -61952,6 +61882,12 @@ } } }, + "react-hook-form": { + "version": "7.42.1", + "resolved": "https://registry.npmjs.org/react-hook-form/-/react-hook-form-7.42.1.tgz", + "integrity": "sha512-2UIGqwMZksd5HS55crTT1ATLTr0rAI4jS7yVuqTaoRVDhY2Qc4IyjskCmpnmdYqUNOYFy04vW253tb2JRVh+IQ==", + "requires": {} + }, "react-icons": { "version": "4.7.1", "resolved": "https://registry.npmjs.org/react-icons/-/react-icons-4.7.1.tgz", @@ -64265,11 +64201,6 @@ "resolved": "https://registry.npmjs.org/tiny-invariant/-/tiny-invariant-1.3.1.tgz", "integrity": "sha512-AD5ih2NlSssTCwsMznbvwMZpJ1cbhkGd2uueNxzv2jDlEeZdU04JQfRnggJQ8DrcVBGjAsCKwFBbDlVNtEMlzw==" }, - "tiny-warning": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/tiny-warning/-/tiny-warning-1.0.3.tgz", - "integrity": "sha512-lBN9zLN/oAf68o3zNXYrdCt1kP8WsiGW8Oo2ka41b2IM5JL/S1CTyX1rW0mb/zSuJun0ZUrDxx4sqvYS2FWzPA==" - }, "tmp": { "version": "0.2.1", "resolved": "https://registry.npmjs.org/tmp/-/tmp-0.2.1.tgz", diff --git a/website/package.json b/website/package.json index 855feb1d..580d0be3 100644 --- a/website/package.json +++ b/website/package.json @@ -46,7 +46,6 @@ "eslint-config-next": "13.0.6", "eslint-plugin-simple-import-sort": "^8.0.0", "focus-visible": "^5.2.0", - "formik": "^2.2.9", "framer-motion": "^6.5.1", "install": "^0.13.0", "next": "13.0.6", @@ -57,6 +56,7 @@ "react": "18.2.0", "react-dom": "18.2.0", "react-feature-flags": "^1.0.0", + "react-hook-form": "^7.42.1", "react-icons": "^4.7.1", "react-table": "^7.8.0", "sharp": "^0.31.3", diff --git a/website/src/components/CollapsableText.tsx b/website/src/components/CollapsableText.tsx index 8d2dff56..21325d29 100644 --- a/website/src/components/CollapsableText.tsx +++ b/website/src/components/CollapsableText.tsx @@ -11,6 +11,8 @@ import { } from "@chakra-ui/react"; import React, { ReactNode } from "react"; +const killEvent = (e) => e.stopPropagation(); + export const CollapsableText = ({ text, maxLength = 220, @@ -44,8 +46,9 @@ export const CollapsableText = ({ - - + {/* we kill the event here to disable drag and drop, since it is in the same container */} + + Full Text {text} diff --git a/website/src/components/Dashboard/LeaderboardTable.tsx b/website/src/components/Dashboard/LeaderboardTable.tsx index b6048f3f..52cf762b 100644 --- a/website/src/components/Dashboard/LeaderboardTable.tsx +++ b/website/src/components/Dashboard/LeaderboardTable.tsx @@ -1,17 +1,16 @@ -import { Box, Link, Stack, StackDivider, Text, useColorModeValue } from "@chakra-ui/react"; +import { Box, Link, Text, useColorModeValue } from "@chakra-ui/react"; import NextLink from "next/link"; -import { get } from "src/lib/api"; -import useSWR from "swr"; +import { LeaderboardGridCell } from "src/components/LeaderboardGridCell"; +import { LeaderboardTimeFrame } from "src/types/Leaderboard"; export function LeaderboardTable() { const backgroundColor = useColorModeValue("white", "gray.700"); const accentColor = useColorModeValue("gray.200", "gray.900"); - const { data: leaderboardEntries } = useSWR("/api/leaderboard", get); return (
- Top 5 Contributors + Top 5 Contributors Today View All -> @@ -25,30 +24,7 @@ export function LeaderboardTable() { borderRadius="xl" className="p-6 shadow-sm" > - } spacing="4"> -
-

Name

-
-

Score

-
-
- {leaderboardEntries?.map(({ display_name, score }, idx) => ( -
-
- {/* - Profile Picture - */} -

{display_name}

- {/* - {item.streakCount} - */} -
- -

{score}

-
-
- ))} -
+
diff --git a/website/src/components/Dashboard/SlimFooter.tsx b/website/src/components/Dashboard/SlimFooter.tsx new file mode 100644 index 00000000..5a7b093c --- /dev/null +++ b/website/src/components/Dashboard/SlimFooter.tsx @@ -0,0 +1,45 @@ +import { Box, Divider } from "@chakra-ui/react"; +import Image from "next/image"; +import Link from "next/link"; +import { useMemo } from "react"; + +export function SlimFooter() { + return ( +
+ + + + + + logo + + + + + +
+ ); +} + +const FooterLink = ({ href, label }: { href: string; label: string }) => + useMemo( + () => ( + + {label} + + ), + [href, label] + ); diff --git a/website/src/components/FlaggableElement.tsx b/website/src/components/FlaggableElement.tsx index 4cd55293..b9e41a32 100644 --- a/website/src/components/FlaggableElement.tsx +++ b/website/src/components/FlaggableElement.tsx @@ -22,7 +22,7 @@ import { } from "@chakra-ui/react"; import { QuestionMarkCircleIcon } from "@heroicons/react/20/solid"; import clsx from "clsx"; -import { useReducer } from "react"; +import { useEffect, useReducer } from "react"; import { FiAlertCircle } from "react-icons/fi"; import { get, post } from "src/lib/api"; import { Message } from "src/types/Conversation"; @@ -100,12 +100,14 @@ export const FlaggableElement = (props: FlaggableElementProps) => { ); const [isEditing, setIsEditing] = useBoolean(); - useSWR("/api/valid_labels", get, { - onSuccess: (data) => { - const { valid_labels } = data; - updateReport({ type: "load_labels", labels: valid_labels }); - }, - }); + const { data, isLoading } = useSWR("/api/valid_labels", get); + useEffect(() => { + if (isLoading) { + return; + } + const { valid_labels } = data; + updateReport({ type: "load_labels", labels: valid_labels }); + }, [data, isLoading]); const { trigger } = useSWRMutation("/api/set_label", post, { onSuccess: () => { diff --git a/website/src/components/Footer.tsx b/website/src/components/Footer.tsx index 9332461d..b22653e9 100644 --- a/website/src/components/Footer.tsx +++ b/website/src/components/Footer.tsx @@ -1,40 +1,66 @@ -import { useColorMode } from "@chakra-ui/react"; +import { Box, Divider, Flex, Text, useColorMode } from "@chakra-ui/react"; import Image from "next/image"; import Link from "next/link"; import { useMemo } from "react"; export function Footer() { const { colorMode } = useColorMode(); - const bgColorClass = colorMode === "light" ? "bg-transparent" : "bg-gray-800"; - const borderClass = colorMode === "light" ? "border-slate-200" : "border-transparent"; + const backgroundColor = colorMode === "light" ? "white" : "gray.800"; + const textColor = colorMode === "light" ? "black" : "gray.300"; return ( -
-
-
- - logo - +
+ + + + + + + logo + + -
-

Open Assistant

-

Conversational AI for everyone.

-
-
+ + + Open Assistant + + + Conversational AI for everyone. + + + - -
+ + +
); } @@ -42,14 +68,10 @@ export function Footer() { const FooterLink = ({ href, label }: { href: string; label: string }) => useMemo( () => ( - - {label} + + + {label} + ), [href, label] diff --git a/website/src/components/Layout.tsx b/website/src/components/Layout.tsx index b26b9bf2..70a2ce2c 100644 --- a/website/src/components/Layout.tsx +++ b/website/src/components/Layout.tsx @@ -1,9 +1,11 @@ // https://nextjs.org/docs/basic-features/layouts +import { Box, Grid } from "@chakra-ui/react"; import type { NextPage } from "next"; import { FiBarChart2, FiLayout, FiMessageSquare, FiUsers } from "react-icons/fi"; import { Header } from "src/components/Header"; +import { SlimFooter } from "./Dashboard/SlimFooter"; import { Footer } from "./Footer"; import { SideMenuLayout } from "./SideMenuLayout"; @@ -28,7 +30,7 @@ export const getTransparentHeaderLayout = (page: React.ReactElement) => ( ); export const getDashboardLayout = (page: React.ReactElement) => ( -
+
( }, ]} > - {page} + + {page} + + + + -
-
+ ); export const getAdminLayout = (page: React.ReactElement) => ( diff --git a/website/src/components/LeaderboardGridCell/LeaderboardGridCell.tsx b/website/src/components/LeaderboardGridCell/LeaderboardGridCell.tsx index 35d83a75..df18735d 100644 --- a/website/src/components/LeaderboardGridCell/LeaderboardGridCell.tsx +++ b/website/src/components/LeaderboardGridCell/LeaderboardGridCell.tsx @@ -8,7 +8,7 @@ import useSWRImmutable from "swr/immutable"; const columns = [ { Header: "Rank", - accessor: (item: LeaderboardEntity, rowIndex: number) => "#" + (item.user_rank + 1), + accessor: "rank", style: { width: "90px" }, }, { diff --git a/website/src/components/Roadmap.tsx b/website/src/components/Roadmap.tsx index 283a4c8f..a29e1984 100644 --- a/website/src/components/Roadmap.tsx +++ b/website/src/components/Roadmap.tsx @@ -51,7 +51,7 @@ const Roadmap = () => {

Growing Up

    -
  • Third-Party Extentions
  • +
  • Third-Party Extensions
  • Device Control
  • Multi-Modality
diff --git a/website/src/components/RoleSelect.tsx b/website/src/components/RoleSelect.tsx new file mode 100644 index 00000000..d39d3868 --- /dev/null +++ b/website/src/components/RoleSelect.tsx @@ -0,0 +1,25 @@ +import { Select, SelectProps } from "@chakra-ui/react"; +import { forwardRef } from "react"; +import { ElementOf } from "src/types/utils"; + +export const roles = ["general", "admin", "banned"] as const; +export type Role = ElementOf; + +type RoleSelectProps = Omit & { + defaultValue?: Role; + value?: Role; +}; + +export const RoleSelect = forwardRef((props, ref) => { + return ( + + ); +}); + +RoleSelect.displayName = "RoleSelect"; diff --git a/website/src/components/SideMenuLayout.tsx b/website/src/components/SideMenuLayout.tsx index a768bc85..1e92a0c7 100644 --- a/website/src/components/SideMenuLayout.tsx +++ b/website/src/components/SideMenuLayout.tsx @@ -12,11 +12,11 @@ export const SideMenuLayout = (props: SideMenuLayoutProps) => { return ( - - + + - {props.children} + {props.children} ); diff --git a/website/src/components/Survey/LabelRadioGroup.tsx b/website/src/components/Survey/LabelRadioGroup.tsx index 617b05b9..bf2521f6 100644 --- a/website/src/components/Survey/LabelRadioGroup.tsx +++ b/website/src/components/Survey/LabelRadioGroup.tsx @@ -1,4 +1,18 @@ -import { Box, Button, Flex, useColorMode } from "@chakra-ui/react"; +import { + Box, + Button, + Flex, + IconButton, + Popover, + PopoverArrow, + PopoverBody, + PopoverCloseButton, + PopoverContent, + PopoverTrigger, + Text, + useColorMode, +} from "@chakra-ui/react"; +import { InformationCircleIcon } from "@heroicons/react/20/solid"; import { useId, useState } from "react"; import { colors } from "src/styles/Theme/colors"; @@ -8,6 +22,17 @@ interface LabelRadioGroupProps { isEditable?: boolean; } +const label_messages: { [label: string]: { description: string; explanation: string[] } } = { + spam: { + description: "The message is spam?", + explanation: [ + 'We consider the following unwanted content as spam: trolling, intentional undermining of our purpose, illegal material, material that violates our code of conduct, and other things that are inappropriate for our dataset. We collect these under the common heading of "spam".', + "This is not an assessment of whether this message is the best possible answer. Especially for prompts or user-replies, we very much want to retain all kinds of responses in the dataset, so that the assistant can learn to reply appropriately.", + "Please mark this text as spam only if it is clearly unsuited to be part of our dataset, as outlined above, and try not to make any subjective value-judgments beyond that.", + ], + }, +}; + export const LabelRadioGroup = (props: LabelRadioGroupProps) => { const [labelValues, setLabelValues] = useState(Array.from({ length: props.labelIDs.length }).map(() => 0)); const [interactionFlag, setInteractionFlag] = useState(false); @@ -17,7 +42,7 @@ export const LabelRadioGroup = (props: LabelRadioGroupProps) => { {props.labelIDs.map((labelId, idx) => ( { const newState = labelValues.slice(); @@ -45,7 +70,7 @@ interface ButtonState { } interface LabelRadioItemProps { - labelId: string; + labelText: { description: string; explanation?: string[] }; labelValue: number; clickHandler: (newVal: number) => unknown; states: ButtonState[]; @@ -63,7 +88,27 @@ const LabelRadioItem = (props: LabelRadioItemProps) => { {props.states.map((item, idx) => ( diff --git a/website/src/components/Tasks/LabelTask.tsx b/website/src/components/Tasks/LabelTask.tsx index 758ee127..636e9cde 100644 --- a/website/src/components/Tasks/LabelTask.tsx +++ b/website/src/components/Tasks/LabelTask.tsx @@ -66,7 +66,7 @@ export const LabelTask = ({ )} - {valid_labels.length === 1 ? ( + {task.mode === "simple" ? ( ) : ( diff --git a/website/src/components/Tasks/Task.tsx b/website/src/components/Tasks/Task.tsx index 0c728a6d..05410d4e 100644 --- a/website/src/components/Tasks/Task.tsx +++ b/website/src/components/Tasks/Task.tsx @@ -27,7 +27,7 @@ export const Task = ({ frontendId, task, trigger, mutate }) => { const replyContent = useRef(null); const [showUnchangedWarning, setShowUnchangedWarning] = useState(false); - const taskType = TaskTypes.find((taskType) => taskType.type === task.type); + const taskType = TaskTypes.find((taskType) => taskType.type === task.type && taskType.mode === task.mode); const { trigger: sendRejection } = useSWRMutation("/api/reject_task", post, { onSuccess: async () => { diff --git a/website/src/components/Tasks/TaskTypes.tsx b/website/src/components/Tasks/TaskTypes.tsx index 6f0bd2e1..4a85ccbd 100644 --- a/website/src/components/Tasks/TaskTypes.tsx +++ b/website/src/components/Tasks/TaskTypes.tsx @@ -11,6 +11,7 @@ export interface TaskInfo { category: TaskCategory; pathname: string; type: string; + mode?: string; overview?: string; instruction?: string; update_type: string; @@ -90,7 +91,7 @@ export const TaskTypes: TaskInfo[] = [ unchanged_title: "Order Unchanged", unchanged_message: "You have not changed the order of the prompts. Are you sure you would like to continue?", }, - // label + // label (full) { label: "Label Initial Prompt", desc: "Provide labels for a prompt.", @@ -98,6 +99,7 @@ export const TaskTypes: TaskInfo[] = [ pathname: "/label/label_initial_prompt", overview: "Provide labels for the following prompt", type: "label_initial_prompt", + mode: "full", update_type: "text_labels", }, { @@ -105,8 +107,9 @@ export const TaskTypes: TaskInfo[] = [ desc: "Provide labels for a prompt.", category: TaskCategory.Label, pathname: "/label/label_prompter_reply", - overview: "Given the following discussion, provide labels for the final promp", + overview: "Given the following discussion, provide labels for the final prompt", type: "label_prompter_reply", + mode: "full", update_type: "text_labels", }, { @@ -116,6 +119,38 @@ export const TaskTypes: TaskInfo[] = [ pathname: "/label/label_assistant_reply", overview: "Given the following discussion, provide labels for the final prompt.", type: "label_assistant_reply", + mode: "full", + update_type: "text_labels", + }, + // label (simple) + { + label: "Classify Initial Prompt", + desc: "Provide labels for a prompt.", + category: TaskCategory.Label, + pathname: "/label/label_initial_prompt", + overview: "Read the following prompt and then answer the question about it.", + type: "label_initial_prompt", + mode: "simple", + update_type: "text_labels", + }, + { + label: "Classify Prompter Reply", + desc: "Provide labels for a prompt.", + category: TaskCategory.Label, + pathname: "/label/label_prompter_reply", + overview: "Read the following conversation and then answer the question about the last prompt in the discussion.", + type: "label_prompter_reply", + mode: "simple", + update_type: "text_labels", + }, + { + label: "Classify Assistant Reply", + desc: "Provide labels for a prompt.", + category: TaskCategory.Label, + pathname: "/label/label_assistant_reply", + overview: "Read the following conversation and then answer the question about the last prompt in the discussion.", + type: "label_assistant_reply", + mode: "simple", update_type: "text_labels", }, ]; diff --git a/website/src/lib/auth.ts b/website/src/lib/auth.ts index 803550a5..42c6cf79 100644 --- a/website/src/lib/auth.ts +++ b/website/src/lib/auth.ts @@ -1,11 +1,12 @@ import type { NextApiRequest, NextApiResponse } from "next"; import { getToken, JWT } from "next-auth/jwt"; +import { Role } from "src/components/RoleSelect"; /** * Wraps any API Route handler and verifies that the user does not have the * specified role. Returns a 403 if they do, otherwise runs the handler. */ -const withoutRole = (role: string, handler: (arg0: NextApiRequest, arg1: NextApiResponse, arg2: JWT) => void) => { +const withoutRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiResponse, arg2: JWT) => void) => { return async (req: NextApiRequest, res: NextApiResponse) => { const token = await getToken({ req }); if (!token || token.role === role) { @@ -20,7 +21,7 @@ const withoutRole = (role: string, handler: (arg0: NextApiRequest, arg1: NextApi * Wraps any API Route handler and verifies that the user has the appropriate * role before running the handler. Returns a 403 otherwise. */ -const withRole = (role: string, handler: (arg0: NextApiRequest, arg1: NextApiResponse) => void) => { +const withRole = (role: Role, handler: (arg0: NextApiRequest, arg1: NextApiResponse) => void) => { return async (req: NextApiRequest, res: NextApiResponse) => { const token = await getToken({ req }); if (!token || token.role !== role) { diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index f94d8b10..c2100c9b 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -170,7 +170,7 @@ export class OasstApiClient { const params = new URLSearchParams(); params.append("max_count", max_count.toString()); - // The backend API uses different query paramters depending on the + // The backend API uses different query parameters depending on the // pagination direction but they both take the same cursor value. // Depending on direction, pick the right query param. if (cursor !== "") { diff --git a/website/src/pages/404.tsx b/website/src/pages/404.tsx index c3d01510..f4c09bbf 100644 --- a/website/src/pages/404.tsx +++ b/website/src/pages/404.tsx @@ -6,9 +6,6 @@ import { PageEmptyState } from "src/components/EmptyState"; import { getTransparentHeaderLayout } from "src/components/Layout"; function Error() { - const router = useRouter(); - const backgroundColor = useColorModeValue("white", "gray.800"); - return ( <> diff --git a/website/src/pages/admin/manage_user/[id].tsx b/website/src/pages/admin/manage_user/[id].tsx index 90698f4a..88bfced4 100644 --- a/website/src/pages/admin/manage_user/[id].tsx +++ b/website/src/pages/admin/manage_user/[id].tsx @@ -1,17 +1,28 @@ -import { Button, Container, FormControl, FormLabel, Input, Select, Stack, useToast } from "@chakra-ui/react"; -import { Field, Form, Formik } from "formik"; +import { Button, Card, CardBody, Container, FormControl, FormLabel, Input, Stack, useToast } from "@chakra-ui/react"; +import { InferGetServerSidePropsType } from "next"; import Head from "next/head"; import { useRouter } from "next/router"; import { useSession } from "next-auth/react"; import { useEffect } from "react"; +import { useForm } from "react-hook-form"; import { getAdminLayout } from "src/components/Layout"; +import { Role, RoleSelect } from "src/components/RoleSelect"; import { UserMessagesCell } from "src/components/UserMessagesCell"; import { post } from "src/lib/api"; import { oasstApiClient } from "src/lib/oasst_api_client"; import prisma from "src/lib/prismadb"; import useSWRMutation from "swr/mutation"; -const ManageUser = ({ user }) => { +interface UserForm { + user_id: string; + id: string; + auth_method: string; + display_name: string; + role: Role; + notes: string; +} + +const ManageUser = ({ user }: InferGetServerSidePropsType) => { const toast = useToast(); const router = useRouter(); const { data: session, status } = useSession(); @@ -51,6 +62,10 @@ const ManageUser = ({ user }) => { }, }); + const { register, handleSubmit } = useForm({ + defaultValues: user, + }); + return ( <> @@ -61,50 +76,31 @@ const ManageUser = ({ user }) => { /> - - { - trigger(values); - }} - > -
- - - - - {({ field }) => ( - - Display Name - - - )} - - - {({ field }) => ( - - Role - - - )} - - - {({ field }) => ( - - Notes - - - )} - - - -
+ + + +
trigger(data))}> + + + + + Display Name + + + + Role + + + + Notes + + + +
+
+
@@ -125,7 +121,7 @@ export async function getServerSideProps({ query }) { }); const user = { ...backend_user, - role: local_user?.role || "general", + role: (local_user?.role || "general") as Role, }; return { props: { diff --git a/website/src/pages/api/auth/[...nextauth].ts b/website/src/pages/api/auth/[...nextauth].ts index 412551fe..691cbcba 100644 --- a/website/src/pages/api/auth/[...nextauth].ts +++ b/website/src/pages/api/auth/[...nextauth].ts @@ -2,12 +2,13 @@ import { PrismaAdapter } from "@next-auth/prisma-adapter"; import { boolean } from "boolean"; import type { AuthOptions } from "next-auth"; import NextAuth from "next-auth"; +import { Provider } from "next-auth/providers"; import CredentialsProvider from "next-auth/providers/credentials"; import DiscordProvider from "next-auth/providers/discord"; import EmailProvider from "next-auth/providers/email"; import prisma from "src/lib/prismadb"; -const providers = []; +const providers: Provider[] = []; // Register an email magic link auth method. providers.push( @@ -39,12 +40,13 @@ if (boolean(process.env.DEBUG_LOGIN) || process.env.NODE_ENV === "development") name: "Debug Credentials", credentials: { username: { label: "Username", type: "text" }, + role: { label: "Role", type: "text" }, }, async authorize(credentials) { const user = { id: credentials.username, name: credentials.username, - role: "admin", + role: credentials.role, }; // save the user to the database await prisma.user.upsert({ diff --git a/website/src/pages/api/leaderboard.ts b/website/src/pages/api/leaderboard.ts index 10a9a3a7..592f3da5 100644 --- a/website/src/pages/api/leaderboard.ts +++ b/website/src/pages/api/leaderboard.ts @@ -6,7 +6,7 @@ import { LeaderboardTimeFrame } from "src/types/Leaderboard"; * Returns the set of valid labels that can be applied to messages. */ const handler = withoutRole("banned", async (req, res) => { - const time_frame = req.query.time_frame as LeaderboardTimeFrame; + const time_frame = (req.query.time_frame as LeaderboardTimeFrame) || LeaderboardTimeFrame.day; const { leaderboard } = await oasstApiClient.fetch_leaderboard(time_frame); res.status(200).json(leaderboard); }); diff --git a/website/src/pages/auth/signin.tsx b/website/src/pages/auth/signin.tsx index c2f05e69..bd442a65 100644 --- a/website/src/pages/auth/signin.tsx +++ b/website/src/pages/auth/signin.tsx @@ -1,14 +1,16 @@ -import { Button, Input, Stack } from "@chakra-ui/react"; +import { Button, ButtonProps, Input, Stack, useColorModeValue } from "@chakra-ui/react"; import { useColorMode } from "@chakra-ui/react"; +import { GetServerSideProps } from "next"; import Head from "next/head"; import Link from "next/link"; import { useRouter } from "next/router"; -import { getCsrfToken, getProviders, signIn } from "next-auth/react"; +import { ClientSafeProvider, getProviders, signIn } from "next-auth/react"; import React, { useEffect, useRef, useState } from "react"; import { FaBug, FaDiscord, FaEnvelope, FaGithub } from "react-icons/fa"; import { AuthLayout } from "src/components/AuthLayout"; import { Footer } from "src/components/Footer"; import { Header } from "src/components/Header"; +import { RoleSelect } from "src/components/RoleSelect"; export type SignInErrorTypes = | "Signin" @@ -37,8 +39,11 @@ const errorMessages: Record = { default: "Unable to sign in.", }; -// eslint-disable-next-line @typescript-eslint/no-unused-vars -function Signin({ csrfToken, providers }) { +interface SigninProps { + providers: Awaited>; +} + +function Signin({ providers }: SigninProps) { const router = useRouter(); const { discord, email, github, credentials } = providers; const emailEl = useRef(null); @@ -60,18 +65,10 @@ function Signin({ csrfToken, providers }) { signIn(email.id, { callbackUrl: "/dashboard", email: emailEl.current.value }); }; - const debugUsernameEl = useRef(null); - function signinWithDebugCredentials(ev: React.FormEvent) { - ev.preventDefault(); - signIn(credentials.id, { callbackUrl: "/dashboard", username: debugUsernameEl.current.value }); - } - const { colorMode } = useColorMode(); const bgColorClass = colorMode === "light" ? "bg-gray-50" : "bg-chakra-gray-900"; const buttonBgColor = colorMode === "light" ? "#2563eb" : "#2563eb"; - const buttonColorScheme = colorMode === "light" ? "blue" : "dark-blue-btn"; - return (
@@ -80,17 +77,7 @@ function Signin({ csrfToken, providers }) { - {credentials && ( -
- For Debugging Only - - - - -
- )} + {credentials && } {email && (
@@ -102,16 +89,9 @@ function Signin({ csrfToken, providers }) { placeholder="Email Address" ref={emailEl} /> - +
)} @@ -179,13 +159,49 @@ Signin.getLayout = (page) => ( export default Signin; -export async function getServerSideProps() { - const csrfToken = await getCsrfToken(); +const SigninButton = (props: ButtonProps) => { + const buttonColorScheme = useColorModeValue("blue", "dark-blue-btn"); + + return ( + + ); +}; + +const DebugSigninForm = ({ credentials, bgColorClass }: { credentials: ClientSafeProvider; bgColorClass: string }) => { + const debugUsernameEl = useRef(null); + const roleRef = useRef(null); + function signinWithDebugCredentials(ev: React.FormEvent) { + ev.preventDefault(); + signIn(credentials.id, { + callbackUrl: "/dashboard", + username: debugUsernameEl.current.value, + role: roleRef.current.value, + }); + } + return ( +
+ For Debugging Only + + + + }>Continue with Debug User + +
+ ); +}; + +export const getServerSideProps: GetServerSideProps = async () => { const providers = await getProviders(); return { props: { - csrfToken, providers, }, }; -} +}; diff --git a/website/src/pages/create/initial_prompt.tsx b/website/src/pages/create/initial_prompt.tsx index dc4429d0..6a51ca25 100644 --- a/website/src/pages/create/initial_prompt.tsx +++ b/website/src/pages/create/initial_prompt.tsx @@ -19,8 +19,8 @@ const InitialPrompt = () => { return ( <> - Reply as Assistant - + Initial Prompt + diff --git a/website/src/pages/create/user_reply.tsx b/website/src/pages/create/user_reply.tsx index b28b7442..8d2981e5 100644 --- a/website/src/pages/create/user_reply.tsx +++ b/website/src/pages/create/user_reply.tsx @@ -19,8 +19,8 @@ const UserReply = () => { return ( <> - Reply as Assistant - + Reply as User + diff --git a/website/src/pages/dashboard.tsx b/website/src/pages/dashboard.tsx index 42481f1e..3971e125 100644 --- a/website/src/pages/dashboard.tsx +++ b/website/src/pages/dashboard.tsx @@ -1,3 +1,4 @@ +import { Flex } from "@chakra-ui/react"; import Head from "next/head"; import { useSession } from "next-auth/react"; import { LeaderboardTable, TaskOption } from "src/components/Dashboard"; @@ -16,8 +17,10 @@ const Dashboard = () => { Dashboard - Open Assistant - - + + + + ); }; diff --git a/website/src/pages/leaderboard.tsx b/website/src/pages/leaderboard.tsx index 131778c9..e53b0c52 100644 --- a/website/src/pages/leaderboard.tsx +++ b/website/src/pages/leaderboard.tsx @@ -1,4 +1,4 @@ -import { Box, Heading, Tabs, TabList, TabPanels, Tab, TabPanel } from "@chakra-ui/react"; +import { Box, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from "@chakra-ui/react"; import Head from "next/head"; import { getDashboardLayout } from "src/components/Layout"; import { LeaderboardGridCell } from "src/components/LeaderboardGridCell"; diff --git a/website/src/pages/privacy-policy.tsx b/website/src/pages/privacy-policy.tsx index cac820fb..1c94b669 100644 --- a/website/src/pages/privacy-policy.tsx +++ b/website/src/pages/privacy-policy.tsx @@ -88,7 +88,7 @@ const PrivacyPolicy = () => { { number: "4", title: "Inquiries", - desc: "When you contact us via e-mail, telephone or telefax, your inquiry, including all personal data arising thereof will be stored by us for the purpose of processing your request. We will not pass on these data without your consent. The processing of these data is based on Article 6 (1) (1) (b) GDPR, if your inquiry is related to the fulfilment of a contract concluded with us or required for the implementation of pre-contractual measures. Furthermore, the processing is based on Article 6 (1) (1) (f) GDPR, because we have a legitimate interest in the effective handling of requests sent to us. In addition, according to Article 6 (1) (1) (c) GDPR we are also entitled to the processing of the above-mentioned data, because we are legally bound to enable fast electronic contact and immediate communication. Of course, your data will only be used strictly according to purpose and only for processing and responding to your request. After final processing, your data will immediately be anonymized or deleted, unless we are bound by a legally prescribed storage period.", + desc: "When you contact us via e-mail, telephone or telefax, your inquiry, including all personal data arising thereof will be stored by us for the purpose of processing your request. We will not pass on these data without your consent. The processing of these data is based on Article 6 (1) (1) (b) GDPR, if your inquiry is related to the fulfillment of a contract concluded with us or required for the implementation of pre-contractual measures. Furthermore, the processing is based on Article 6 (1) (1) (f) GDPR, because we have a legitimate interest in the effective handling of requests sent to us. In addition, according to Article 6 (1) (1) (c) GDPR we are also entitled to the processing of the above-mentioned data, because we are legally bound to enable fast electronic contact and immediate communication. Of course, your data will only be used strictly according to purpose and only for processing and responding to your request. After final processing, your data will immediately be anonymized or deleted, unless we are bound by a legally prescribed storage period.", sections: [], }, { diff --git a/website/src/pages/terms-of-service.tsx b/website/src/pages/terms-of-service.tsx index f58a9084..b0e298ba 100644 --- a/website/src/pages/terms-of-service.tsx +++ b/website/src/pages/terms-of-service.tsx @@ -14,7 +14,7 @@ const TermsOfService = () => { { number: "1.1", title: "", - desc: `LAION (association in formation), Marie-Henning-Weg 143, 21035 Hamburg (hereinafter referred to as: "LAION") operates an online portal for the producing a machine learning model called Open Assistant using crowdsourced data.`, + desc: `LAION (association in formation), Marie-Henning-Weg 143, 21035 Hamburg (hereinafter referred to as: "LAION") operates an online portal for the producing a machine learning model called Open Assistant using crowd-sourced data.`, }, { number: "1.2", diff --git a/website/src/styles/Home.module.css b/website/src/styles/Home.module.css deleted file mode 100644 index e69de29b..00000000 diff --git a/website/src/styles/Theme/components/Card.ts b/website/src/styles/Theme/components/Card.ts new file mode 100644 index 00000000..8cd66031 --- /dev/null +++ b/website/src/styles/Theme/components/Card.ts @@ -0,0 +1,27 @@ +import { cardAnatomy } from "@chakra-ui/anatomy"; +import { createMultiStyleConfigHelpers } from "@chakra-ui/react"; + +const { definePartsStyle, defineMultiStyleConfig } = createMultiStyleConfigHelpers(cardAnatomy.keys); + +export const cardTheme = defineMultiStyleConfig({ + baseStyle: definePartsStyle(({ colorMode }) => { + const isLightMode = colorMode === "light"; + return { + container: { + backgroundColor: isLightMode ? "white" : "gray.700", + }, + header: {}, + body: { + padding: 6, + }, + footer: {}, + }; + }), + variants: { + elevated: definePartsStyle({ + container: { + borderRadius: "xl", + }, + }), + }, +}); diff --git a/website/src/styles/Theme/index.ts b/website/src/styles/Theme/index.ts index 718eab7f..37ca424d 100644 --- a/website/src/styles/Theme/index.ts +++ b/website/src/styles/Theme/index.ts @@ -2,6 +2,7 @@ import { type ThemeConfig, extendTheme } from "@chakra-ui/react"; import { Styles } from "@chakra-ui/theme-tools"; import { colors } from "./colors"; +import { cardTheme } from "./components/Card"; import { containerTheme } from "./components/Container"; const config: ThemeConfig = { @@ -12,6 +13,7 @@ const config: ThemeConfig = { const components = { Container: containerTheme, + Card: cardTheme, }; const breakpoints = { diff --git a/website/src/types/Leaderboard.ts b/website/src/types/Leaderboard.ts index 2e72fafc..21c91766 100644 --- a/website/src/types/Leaderboard.ts +++ b/website/src/types/Leaderboard.ts @@ -16,7 +16,7 @@ export interface LeaderboardReply { } export interface LeaderboardEntity { - user_rank: number; + rank: number; user_id: string; username: string; auth_method: string; diff --git a/website/src/types/utils.ts b/website/src/types/utils.ts new file mode 100644 index 00000000..82c35036 --- /dev/null +++ b/website/src/types/utils.ts @@ -0,0 +1,3 @@ +// https://github.com/ts-essentials/ts-essentials/blob/25cae45c162f8784e3cdae8f43783d0c66370a57/lib/types.ts#L437 +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export type ElementOf = T extends readonly (infer ET)[] ? ET : never;