From c8d16285d016822b0b03f821b04e099f1b4c38c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Sat, 28 Jan 2023 15:05:46 +0100 Subject: [PATCH 01/13] Import message trees from jsonl file (#964) * add new backlog_ranking tree state * add first version of import script * allow activation of trees during import * add min_active_rankings_per_lang config param * add settings docstring --- ...add_origin_column_to_message_tree_state.py | 31 +++ backend/import.py | 187 ++++++++++++++++++ backend/main.py | 4 +- backend/oasst_backend/api/deps.py | 6 +- backend/oasst_backend/config.py | 9 + backend/oasst_backend/models/message.py | 5 + .../models/message_tree_state.py | 5 + backend/oasst_backend/prompt_repository.py | 6 +- backend/oasst_backend/tree_manager.py | 108 +++++++--- backend/oasst_backend/user_repository.py | 43 +++- .../oasst_backend/user_stats_repository.py | 4 + backend/oasst_backend/utils/tree_export.py | 18 +- 12 files changed, 377 insertions(+), 49 deletions(-) create mode 100644 backend/alembic/versions/2023_01_28_1157-49d8445b4c90_add_origin_column_to_message_tree_state.py create mode 100644 backend/import.py diff --git a/backend/alembic/versions/2023_01_28_1157-49d8445b4c90_add_origin_column_to_message_tree_state.py b/backend/alembic/versions/2023_01_28_1157-49d8445b4c90_add_origin_column_to_message_tree_state.py new file mode 100644 index 00000000..e11f8f9e --- /dev/null +++ b/backend/alembic/versions/2023_01_28_1157-49d8445b4c90_add_origin_column_to_message_tree_state.py @@ -0,0 +1,31 @@ +"""add origin column to message_tree_state + +Revision ID: 49d8445b4c90 +Revises: f856bf19d32b +Create Date: 2023-01-28 11:57:45.580027 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "49d8445b4c90" +down_revision = "f856bf19d32b" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("message", sa.Column("synthetic", sa.Boolean(), server_default=sa.text("false"), nullable=False)) + op.add_column("message", sa.Column("model_name", sa.String(length=1024), nullable=True)) + op.add_column("message_tree_state", sa.Column("origin", sa.String(length=1024), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("message_tree_state", "origin") + op.drop_column("message", "model_name") + op.drop_column("message", "synthetic") + # ### end Alembic commands ### diff --git a/backend/import.py b/backend/import.py new file mode 100644 index 00000000..46a9c2cb --- /dev/null +++ b/backend/import.py @@ -0,0 +1,187 @@ +import argparse +import json +from pathlib import Path +from typing import Optional +from uuid import UUID + +import oasst_backend.models.db_payload as db_payload +import oasst_backend.utils.database_utils as db_utils +import pydantic +from loguru import logger +from oasst_backend.api.deps import create_api_client +from oasst_backend.models import ApiClient, Message +from oasst_backend.models.message_tree_state import MessageTreeState +from oasst_backend.models.message_tree_state import State as TreeState +from oasst_backend.models.payload_column_type import PayloadContainer +from oasst_backend.prompt_repository import PromptRepository +from oasst_backend.user_repository import UserRepository +from oasst_backend.utils.tree_export import ExportMessageNode, ExportMessageTree +from sqlmodel import Session + +# well known id +IMPORT_API_CLIENT_ID = UUID("bd8fde8b-1d8e-4e9a-9966-e96d000f8363") + + +class Importer: + def __init__(self, db: Session, origin: str, model_name: Optional[str] = None): + self.db = db + self.origin = origin + self.model_name = model_name + + # get import api client + api_client = db.query(ApiClient).filter(ApiClient.id == IMPORT_API_CLIENT_ID).first() + if not api_client: + api_client = create_api_client( + session=db, + description="API client used for importing data", + frontend_type="import", + force_id=IMPORT_API_CLIENT_ID, + ) + + ur = UserRepository(db, api_client) + self.import_user = ur.lookup_system_user(username="import") + self.pr = PromptRepository(db=db, api_client=api_client, user_repository=ur) + self.api_client = api_client + + def fetch_message_tree_state(self, message_tree_id: UUID) -> MessageTreeState: + return self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one_or_none() + + def import_message( + self, message: ExportMessageNode, message_tree_id: UUID, parent_id: Optional[UUID] = None + ) -> Message: + payload = db_payload.MessagePayload(text=message.text) + msg = Message( + id=message.message_id, + message_tree_id=message_tree_id, + frontend_message_id=message.message_id, + parent_id=parent_id, + review_count=message.review_count or 0, + lang=message.lang or "en", + review_result=True, + synthetic=message.synthetic if message.synthetic is not None else True, + model_name=message.model_name or self.model_name, + role=message.role, + api_client_id=self.api_client.id, + payload_type=type(payload).__name__, + payload=PayloadContainer(payload=payload), + user_id=self.import_user.id, + ) + self.db.add(msg) + if message.replies: + for r in message.replies: + self.import_message(r, message_tree_id=message_tree_id, parent_id=msg.id) + self.db.flush() + if parent_id is None: + self.pr.update_children_counts(msg.id) + self.db.refresh(msg) + return msg + + def import_tree( + self, tree: ExportMessageTree, state: TreeState = TreeState.BACKLOG_RANKING + ) -> tuple[MessageTreeState, Message]: + assert tree.message_tree_id is not None and tree.message_tree_id == tree.prompt.message_id + root_msg = self.import_message(tree.prompt, message_tree_id=tree.prompt.message_id) + assert state == TreeState.BACKLOG_RANKING or state == TreeState.RANKING, f"{state} not supported for import" + active = state == TreeState.RANKING + mts = MessageTreeState( + message_tree_id=root_msg.id, + goal_tree_size=0, + max_depth=0, + max_children_count=0, + state=state, + origin=self.origin, + active=active, + ) + self.db.add(mts) + return mts, root_msg + + +def import_file( + input_file_path: Path, + origin: str, + *, + model_name: Optional[str] = None, + num_activate: int = 0, + max_count: Optional[int] = None, + dry_run: bool = False, +) -> int: + @db_utils.managed_tx_function(auto_commit=db_utils.CommitMode.ROLLBACK if dry_run else db_utils.CommitMode.COMMIT) + def import_tx(db: Session) -> int: + importer = Importer(db, origin=origin, model_name=model_name) + i = 0 + with input_file_path.open() as file_in: + # read line tree object + for line in file_in: + dict_tree = json.loads(line) + + # validate data + tree: ExportMessageTree = pydantic.parse_obj_as(ExportMessageTree, dict_tree) + existing_mts = importer.fetch_message_tree_state(tree.message_tree_id) + if existing_mts: + logger.info(f"Skipping existing message tree: {tree.message_tree_id}") + else: + state = TreeState.BACKLOG_RANKING if i >= num_activate else TreeState.RANKING + mts, root_msg = importer.import_tree(tree, state=state) + i += 1 + logger.info( + f"imported tree: {mts.message_tree_id}, {mts.state=}, {mts.active=}, {root_msg.children_count=}" + ) + + if max_count and i >= max_count: + logger.info(f"Reached max count {max_count} of trees to import.") + break + return i + + if dry_run: + logger.info("DRY RUN with rollback") + return import_tx() + + +def parse_args(): + def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + parser = argparse.ArgumentParser() + parser.add_argument( + "input_file_path", + help="Input file path", + ) + parser.add_argument("--origin", type=str, default=None, help="Value for origin of message trees") + parser.add_argument("--model_name", type=str, default=None, help="Default name of model (if missing in messages)") + parser.add_argument("--num_activate", type=int, default=0, help="Number of trees to add in ranking state") + parser.add_argument("--max_count", type=int, default=None, help="Maximum number of message trees to import") + parser.add_argument("--dry_run", type=str2bool, default=False) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + input_file_path = Path(args.input_file_path) + if not input_file_path.exists() or not input_file_path.is_file(): + print("Invalid input file:", args.input_file_path) + exit(1) + + dry_run = args.dry_run + num_imported = import_file( + input_file_path, + origin=args.origin or input_file_path.name, + model_name=args.model_name, + num_activate=args.num_activate, + max_count=args.max_count, + dry_run=dry_run, + ) + + logger.info(f"Done ({num_imported=}, {dry_run=})") + + +if __name__ == "__main__": + main() diff --git a/backend/main.py b/backend/main.py index ea4b25da..06c67a2e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -191,6 +191,7 @@ if settings.DEBUG_USE_SEED_DATA: review_count=5, review_result=True, check_tree_state=False, + check_duplicate=False, ) if message.parent_id is None: tm._insert_default_state( @@ -215,7 +216,8 @@ def ensure_tree_states(): try: logger.info("Startup: TreeManager.ensure_tree_states()") with Session(engine) as db: - tm = TreeManager(db, None) + api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db) + tm = TreeManager(db, PromptRepository(db, api_client=api_client)) tm.ensure_tree_states() except Exception: diff --git a/backend/oasst_backend/api/deps.py b/backend/oasst_backend/api/deps.py index dd10bd20..af29285a 100644 --- a/backend/oasst_backend/api/deps.py +++ b/backend/oasst_backend/api/deps.py @@ -1,6 +1,7 @@ from http import HTTPStatus from secrets import token_hex -from typing import Generator, NamedTuple +from typing import Generator, NamedTuple, Optional +from uuid import UUID from fastapi import Depends, Request, Response, Security from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -67,6 +68,7 @@ def create_api_client( trusted: bool | None = False, admin_email: str | None = None, api_key: str | None = None, + force_id: Optional[UUID] = None, ) -> ApiClient: if api_key is None: api_key = token_hex(32) @@ -79,6 +81,8 @@ def create_api_client( trusted=trusted, admin_email=admin_email, ) + if force_id: + api_client.id = force_id session.add(api_client) session.commit() session.refresh(api_client) diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 92316037..5c566b1e 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -46,6 +46,15 @@ class TreeManagerConfiguration(BaseModel): num_required_rankings: int = 3 """Number of rankings in which the message participated.""" + p_activate_backlog_tree: float = 0.8 + """Probability to activate a message tree in BACKLOG_RANKING state when another tree enters + a terminal state. Use this settting to control ratio of initial prompts and backlog tree + activations.""" + + min_active_rankings_per_lang: int = 2 + """When the number of active ranking tasks is below this value when a tree enters a terminal + state an available trees in BACKLOG_RANKING will be actived (i.e. enters the RANKING state).""" + labels_initial_prompt: list[TextLabel] = [ TextLabel.spam, TextLabel.quality, diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index 24fafc01..2cf3b7bb 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -57,6 +57,11 @@ class Message(SQLModel, table=True): rank: Optional[int] = Field(nullable=True) + synthetic: Optional[bool] = Field( + sa_column=sa.Column(sa.Boolean, default=False, server_default=false(), nullable=False) + ) + model_name: Optional[str] = Field(sa_column=sa.Column(sa.String(1024), nullable=True)) + emojis: Optional[dict[str, int]] = Field(default=None, sa_column=sa.Column(pg.JSONB), nullable=False) _user_emojis: Optional[list[str]] = PrivateAttr(default=None) diff --git a/backend/oasst_backend/models/message_tree_state.py b/backend/oasst_backend/models/message_tree_state.py index 2f7ce363..00b94967 100644 --- a/backend/oasst_backend/models/message_tree_state.py +++ b/backend/oasst_backend/models/message_tree_state.py @@ -43,6 +43,9 @@ class State(str, Enum): HALTED_BY_MODERATOR = "halted_by_moderator" """A moderator decided to manually halt the message tree construction process.""" + BACKLOG_RANKING = "backlog_ranking" + """Imported tree ready to be activated and ranked by users (currently inactive).""" + VALID_STATES = ( State.INITIAL_PROMPT_REVIEW, @@ -51,6 +54,7 @@ VALID_STATES = ( State.READY_FOR_SCORING, State.READY_FOR_EXPORT, State.ABORTED_LOW_GRADE, + State.BACKLOG_RANKING, ) TERMINAL_STATES = (State.READY_FOR_EXPORT, State.ABORTED_LOW_GRADE, State.SCORING_FAILED, State.HALTED_BY_MODERATOR) @@ -67,3 +71,4 @@ class MessageTreeState(SQLModel, table=True): max_children_count: int = Field(nullable=False) state: str = Field(nullable=False, max_length=128, index=True) active: bool = Field(nullable=False, index=True) + origin: str = Field(sa_column=sa.Column(sa.String(1024), nullable=True)) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 01f63d26..c69f7340 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -177,6 +177,7 @@ class PromptRepository: review_count: int = 0, review_result: bool = False, check_tree_state: bool = True, + check_duplicate: bool = True, ) -> Message: self.ensure_user_is_enabled() @@ -199,7 +200,7 @@ class PromptRepository: logger.error(f"Message size {len(text)=} exceeds size limit of {settings.MESSAGE_SIZE_LIMIT=}.") raise OasstError("Message size too long.", OasstErrorCode.TASK_MESSAGE_TOO_LONG) - if self.check_users_recent_replies_for_duplicates(text): + if check_duplicate and self.check_users_recent_replies_for_duplicates(text): raise OasstError("User recent messages have duplicates", OasstErrorCode.TASK_MESSAGE_DUPLICATED) if task.parent_message_id: @@ -909,8 +910,7 @@ FROM ( ) AS cc WHERE message.id = cc.id; """ - r = self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id}) - logger.debug(f"update_children_count({message_tree_id=}): {r.rowcount} rows.") + self.db.execute(text(sql_update_children_count), {"message_tree_id": message_tree_id}) @managed_tx_method(CommitMode.COMMIT) def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True): diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 782b9af8..bcab02c3 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -25,7 +25,7 @@ from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingM from oasst_backend.utils.ranking import ranked_pairs from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema -from sqlmodel import Session, func, not_, text, update +from sqlmodel import Session, func, not_, or_, text, update class TaskType(Enum): @@ -73,6 +73,7 @@ class IncompleteRankingsRow(pydantic.BaseModel): role: str children_count: int child_min_ranking_count: int + message_tree_id: UUID class Config: orm_mode = True @@ -625,19 +626,28 @@ class TreeManager: return protocol_schema.TaskDone() - @managed_tx_method(CommitMode.FLUSH) def _enter_state(self, mts: MessageTreeState, state: message_tree_state.State): - assert mts and mts.active + assert mts is_terminal = state in message_tree_state.TERMINAL_STATES - + was_active = mts.active if is_terminal: mts.active = False mts.state = state.value self.db.add(mts) + self.db.flush if is_terminal: logger.info(f"Tree entered terminal '{mts.state}' state ({mts.message_tree_id=})") + root_msg = self.pr.fetch_message(message_id=mts.message_tree_id, fail_if_missing=False) + if root_msg and was_active: + if random.random() < self.cfg.p_activate_backlog_tree: + self.activate_backlog_tree(lang=root_msg.lang) + + if self.cfg.min_active_rankings_per_lang > 0: + incomplete_rankings = self.query_incomplete_rankings(lang=root_msg.lang) + if len(incomplete_rankings) < self.cfg.min_active_rankings_per_lang: + self.activate_backlog_tree(lang=root_msg.lang) else: logger.info(f"Tree entered '{mts.state}' state ({mts.message_tree_id=})") @@ -680,24 +690,30 @@ class TreeManager: self._enter_state(mts, message_tree_state.State.RANKING) return True - def check_condition_for_scoring_state( - self, message_tree_id: UUID - ) -> Tuple[bool, dict[UUID, list[MessageReaction]]]: + def check_condition_for_scoring_state(self, message_tree_id: UUID) -> bool: logger.debug(f"check_condition_for_scoring_state({message_tree_id=})") mts = self.pr.fetch_tree_state(message_tree_id) - if not mts.active or mts.state != message_tree_state.State.RANKING: - logger.debug(f"False {mts.active=}, {mts.state=}") - return False, None + if mts.state != message_tree_state.State.SCORING_FAILED: + if not mts.active or mts.state not in ( + message_tree_state.State.RANKING, + message_tree_state.State.READY_FOR_SCORING, + ): + logger.debug(f"False {mts.active=}, {mts.state=}") + return False ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant" rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter) for parent_msg_id, ranking in rankings_by_message.items(): if len(ranking) < self.cfg.num_required_rankings: logger.debug(f"False {parent_msg_id=} {len(ranking)=}") - return False, None + return False - self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING) + if ( + mts.state != message_tree_state.State.SCORING_FAILED + and mts.state != message_tree_state.State.READY_FOR_SCORING + ): + self._enter_state(mts, message_tree_state.State.READY_FOR_SCORING) self.update_message_ranks(message_tree_id, rankings_by_message) return True @@ -759,8 +775,35 @@ class TreeManager: return False self._enter_state(mts, message_tree_state.State.READY_FOR_EXPORT) + return True + def activate_backlog_tree(self, lang: str) -> MessageTreeState: + while True: + # find tree in backlog state + backlog_tree: MessageTreeState = ( + self.db.query(MessageTreeState) + .join(Message, MessageTreeState.message_tree_id == Message.id) # root msg + .filter(MessageTreeState.state == message_tree_state.State.BACKLOG_RANKING) + .filter(Message.lang == lang) + .limit(1) + .one_or_none() + ) + + if not backlog_tree: + return None + + if len(self.query_tree_ranking_results(message_tree_id=backlog_tree.message_tree_id)) == 0: + logger.info( + f"Backlog tree {backlog_tree.message_tree_id} has no children to rank, aborting with 'aborted_low_grade' state." + ) + self._enter_state(backlog_tree, message_tree_state.State.ABORTED_LOW_GRADE) + else: + logger.info(f"Activating backlog tree {backlog_tree.message_tree_id}") + backlog_tree.active = True + self._enter_state(backlog_tree, message_tree_state.State.RANKING) + return backlog_tree + def _calculate_acceptance(self, labels: list[TextLabels]): # calculate acceptance based on spam label return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels]) @@ -828,7 +871,8 @@ class TreeManager: _sql_find_incomplete_rankings = """ -- find incomplete rankings SELECT m.parent_id, m.role, COUNT(m.id) children_count, MIN(m.ranking_count) child_min_ranking_count, - COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings + COUNT(m.id) FILTER (WHERE m.ranking_count >= :num_required_rankings) as completed_rankings, + mts.message_tree_id FROM message_tree_state mts INNER JOIN message m ON mts.message_tree_id = m.message_tree_id WHERE mts.active -- only consider active trees @@ -837,7 +881,7 @@ WHERE mts.active -- only consider active trees AND m.lang = :lang -- matches lang AND NOT m.deleted -- not deleted AND m.parent_id IS NOT NULL -- ignore initial prompts -GROUP BY m.parent_id, m.role +GROUP BY m.parent_id, m.role, mts.message_tree_id HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings """ @@ -846,7 +890,8 @@ HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings WITH incomplete_rankings AS ({_sql_find_incomplete_rankings}) SELECT ir.* FROM incomplete_rankings ir LEFT JOIN message_reaction mr ON ir.parent_id = mr.message_id AND mr.payload_type = 'RankingReactionPayload' -GROUP BY ir.parent_id, ir.role, ir.children_count, ir.child_min_ranking_count, ir.completed_rankings +GROUP BY ir.parent_id, ir.role, ir.children_count, ir.child_min_ranking_count, ir.completed_rankings, + ir.message_tree_id HAVING(COUNT(mr.message_id) FILTER (WHERE mr.user_id = :user_id) = 0) """ @@ -985,8 +1030,8 @@ SELECT p.parent_id, mr.* FROM GROUP BY m.parent_id, m.message_tree_id HAVING COUNT(m.id) > 1 ) as p -INNER JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload') -INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload' +LEFT JOIN task t ON p.parent_id = t.parent_message_id AND t.done AND (t.payload_type = 'RankPrompterRepliesPayload' OR t.payload_type = 'RankAssistantRepliesPayload') +LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'RankingReactionPayload' """ def query_tree_ranking_results( @@ -1029,7 +1074,14 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki self._insert_default_state(id, state=state) rankings = ( - self.db.query(MessageTreeState).filter(MessageTreeState.state == message_tree_state.State.RANKING).all() + self.db.query(MessageTreeState) + .filter( + or_( + MessageTreeState.state == message_tree_state.State.RANKING, + MessageTreeState.state == message_tree_state.State.READY_FOR_SCORING, + ) + ) + .all() ) if len(rankings) > 0: logger.info(f"Checking state of {len(rankings)} message trees in ranking state.") @@ -1322,17 +1374,17 @@ DELETE FROM user_stats WHERE user_id = :user_id; @managed_tx_method(CommitMode.COMMIT) def retry_scoring_failed_message_trees(self): - query = self.db.query(MessageTreeState.message_tree_id).filter( + query = self.db.query(MessageTreeState).filter( MessageTreeState.state == message_tree_state.State.SCORING_FAILED ) - ranking_role_filter = None if self.cfg.rank_prompter_replies else "assistant" - for row in query.all(): + for mts in query.all(): + mts: MessageTreeState try: - message_tree_id = row["message_tree_id"] - rankings_by_message = self.query_tree_ranking_results(message_tree_id, role_filter=ranking_role_filter) - self.update_message_ranks(message_tree_id=message_tree_id, rankings_by_message=rankings_by_message) + if not self.check_condition_for_scoring_state(mts.message_tree_id): + mts.active = True + self._enter_state(message_tree_state.State.RANKING) except Exception: - logger.exception(f"retry_scoring_failed_message_trees failed for ({message_tree_id=})") + logger.exception(f"retry_scoring_failed_message_trees failed for ({mts.message_tree_id=})") if __name__ == "__main__": @@ -1366,8 +1418,8 @@ if __name__ == "__main__": # print("next_task:", tm.next_task()) - # print( - # ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("2ac20d38-6650-43aa-8bb3-f61080c0d921")) - # ) + print( + ".query_tree_ranking_results", tm.query_tree_ranking_results(UUID("21f9d585-d22c-44ab-a696-baa3d83b5f1b")) + ) # print(tm.export_trees_to_file(message_tree_ids=["7e75fb38-e664-4e2b-817c-b9a0b01b0074"], file="lol.jsonl")) diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 6f96748f..8d6187fb 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -117,13 +117,20 @@ class UserRepository: self.db.add(user) @managed_tx_method(CommitMode.COMMIT) - def _lookup_client_user_tx(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]: + def _lookup_user_tx( + self, + *, + username: str, + auth_method: str, + display_name: Optional[str] = None, + create_missing: bool = True, + ) -> User | None: user: User = ( self.db.query(User) .filter( User.api_client_id == self.api_client.id, - User.username == client_user.id, - User.auth_method == client_user.auth_method, + User.username == username, + User.auth_method == auth_method, ) .first() ) @@ -131,30 +138,46 @@ class UserRepository: if create_missing: # user is unknown, create new record user = User( - username=client_user.id, - display_name=client_user.display_name, + username=username, + display_name=display_name, api_client_id=self.api_client.id, - auth_method=client_user.auth_method, + auth_method=auth_method, + show_on_leaderboard=(auth_method != "system"), # don't show system users, e.g. import user ) self.db.add(user) - elif client_user.display_name and client_user.display_name != user.display_name: + elif display_name and display_name != user.display_name: # we found the user but the display name changed - user.display_name = client_user.display_name + user.display_name = display_name self.db.add(user) + return user - def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]: + def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> User | None: if not client_user: return None num_retries = settings.DATABASE_MAX_TX_RETRY_COUNT for i in range(num_retries): try: - return self._lookup_client_user_tx(client_user, create_missing) + return self._lookup_user_tx( + username=client_user.id, + auth_method=client_user.auth_method, + display_name=client_user.display_name, + create_missing=create_missing, + ) except IntegrityError: # catch UniqueViolation exception, for concurrent requests due to conflicts in ix_user_username if i + 1 == num_retries: raise + @managed_tx_method(CommitMode.COMMIT) + def lookup_system_user(self, username: str, create_missing: bool = True) -> User | None: + return self._lookup_user_tx( + username=username, + auth_method="system", + display_name=f"__system__/{username}", + create_missing=create_missing, + ) + def query_users_ordered_by_username( self, api_client_id: Optional[UUID] = None, diff --git a/backend/oasst_backend/user_stats_repository.py b/backend/oasst_backend/user_stats_repository.py index 81e5cb92..4cc0fc84 100644 --- a/backend/oasst_backend/user_stats_repository.py +++ b/backend/oasst_backend/user_stats_repository.py @@ -214,6 +214,10 @@ class UserStatsRepository: d = delete(UserStats).where(UserStats.time_frame == time_frame_key) self.session.execute(d) + if None in stats_by_user: + logger.warning("Some messages in DB have NULL values in user_id column.") + del stats_by_user[None] + # compute magic leader score for v in stats_by_user.values(): v.leader_score = v.compute_leader_score() diff --git a/backend/oasst_backend/utils/tree_export.py b/backend/oasst_backend/utils/tree_export.py index ee3de9d7..5cd69abe 100644 --- a/backend/oasst_backend/utils/tree_export.py +++ b/backend/oasst_backend/utils/tree_export.py @@ -12,12 +12,15 @@ from pydantic import BaseModel class ExportMessageNode(BaseModel): message_id: str - parent_id: Optional[str] - text: Optional[str] + parent_id: str | None + text: str role: str - review_count: Optional[int] - rank: Optional[int] - replies: Optional[list[ExportMessageNode]] + lang: str | None + review_count: int | None + rank: int | None + synthetic: bool | None + model_name: str | None + replies: list[ExportMessageNode] | None @classmethod def prep_message_export(cls, message: Message) -> ExportMessageNode: @@ -26,14 +29,17 @@ class ExportMessageNode(BaseModel): parent_id=str(message.parent_id) if message.parent_id else None, text=str(message.payload.payload.text), role=message.role, + lang=message.lang, review_count=message.review_count, + synthetic=message.synthetic, + model_name=message.model_name, rank=message.rank, ) class ExportMessageTree(BaseModel): message_tree_id: str - replies: Optional[ExportMessageNode] + prompt: Optional[ExportMessageNode] def build_export_tree(message_tree_id: str, messages: list[Message]) -> ExportMessageTree: From 264e914225611a418de70583e07a71d44ef2244d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Sat, 28 Jan 2023 15:07:46 +0100 Subject: [PATCH 02/13] exclude fails_task from default valid labels --- backend/oasst_backend/api/v1/text_labels.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/oasst_backend/api/v1/text_labels.py b/backend/oasst_backend/api/v1/text_labels.py index 2025fd4c..affc81bb 100644 --- a/backend/oasst_backend/api/v1/text_labels.py +++ b/backend/oasst_backend/api/v1/text_labels.py @@ -48,6 +48,7 @@ def get_valid_lables() -> ValidLabelsResponse: valid_labels=[ LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text) for l in TextLabel + if l != TextLabel.fails_task ] ) From 19116f7251b366ac0dae5139e2a10bedc1bf7070 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Sat, 28 Jan 2023 15:29:38 +0100 Subject: [PATCH 03/13] add optional message_id query param to text_labels/valid_labels endpoint --- backend/oasst_backend/api/v1/text_labels.py | 28 ++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/backend/oasst_backend/api/v1/text_labels.py b/backend/oasst_backend/api/v1/text_labels.py index affc81bb..594ba0df 100644 --- a/backend/oasst_backend/api/v1/text_labels.py +++ b/backend/oasst_backend/api/v1/text_labels.py @@ -1,13 +1,19 @@ +from typing import Optional +from uuid import UUID + from fastapi import APIRouter, Depends, HTTPException from fastapi.security.api_key import APIKey from loguru import logger from oasst_backend.api import deps +from oasst_backend.config import settings +from oasst_backend.models import ApiClient from oasst_backend.prompt_repository import PromptRepository from oasst_backend.schemas.text_labels import LabelDescription, ValidLabelsResponse from oasst_backend.utils.database_utils import CommitMode, managed_tx_function from oasst_shared.exceptions import OasstError from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.schemas.protocol import TextLabel +from sqlmodel import Session from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST router = APIRouter() @@ -43,12 +49,28 @@ def label_text( @router.get("/valid_labels") -def get_valid_lables() -> ValidLabelsResponse: +def get_valid_lables( + *, + message_id: Optional[UUID] = None, + db: Session = Depends(deps.get_db), + api_client: ApiClient = Depends(deps.get_api_client), +) -> ValidLabelsResponse: + if message_id: + pr = PromptRepository(db, api_client=api_client) + message = pr.fetch_message(message_id=message_id) + if message.parent_id is None: + valid_labels = settings.tree_manager.labels_initial_prompt + elif message.role == "assistant": + valid_labels = settings.tree_manager.labels_assistant_reply + else: + valid_labels = settings.tree_manager.labels_prompter_reply + else: + valid_labels = [l for l in TextLabel if l != TextLabel.fails_task] + return ValidLabelsResponse( valid_labels=[ LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text) - for l in TextLabel - if l != TextLabel.fails_task + for l in valid_labels ] ) From 314c590dd24316ccfb110cb753d0aa7234595ca9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Sat, 28 Jan 2023 16:18:31 +0100 Subject: [PATCH 04/13] include import.py in backend docker image --- docker/Dockerfile.backend | 1 + 1 file changed, 1 insertion(+) diff --git a/docker/Dockerfile.backend b/docker/Dockerfile.backend index c89a0280..3401463c 100644 --- a/docker/Dockerfile.backend +++ b/docker/Dockerfile.backend @@ -13,5 +13,6 @@ RUN pip install -e /oasst-shared COPY ./backend/alembic /app/alembic COPY ./backend/alembic.ini /app/alembic.ini COPY ./backend/main.py /app/main.py +COPY ./backend/import.py /app/import.py COPY ./backend/oasst_backend /app/oasst_backend COPY ./backend/test_data /app/test_data From ab4dce3f600ca2c8df2c3901f2dc7eff2568e86b Mon Sep 17 00:00:00 2001 From: Adrian Cowan Date: Sun, 29 Jan 2023 03:07:43 +1100 Subject: [PATCH 05/13] website: Support new widget types for labelling (#966) * website: Support new widget types for labelling Adds proper support for yes/no spam style questions as well as a simple interface for flag style labels. Also cleaned up the Task component to fix some rerender issues. * website: Fix some UI text, adjust yes/no button alignment * website: Remove left over console.log Co-authored-by: notmd <33456881+notmd@users.noreply.github.com> --------- Co-authored-by: notmd <33456881+notmd@users.noreply.github.com> --- .../e2e/tasks/label_assistant_reply.cy.ts | 4 + .../e2e/tasks/label_initial_prompt.cy.ts | 4 + .../e2e/tasks/label_prompter_reply.cy.ts | 4 + website/cypress/e2e/tasks/random.cy.ts | 13 +- website/public/locales/en/common.json | 4 +- website/public/locales/en/labelling.json | 16 ++ website/src/components/FlaggableElement.tsx | 116 ---------- website/src/components/Messages.tsx | 20 +- .../components/Messages/LabelFlagGroup.tsx | 32 +++ .../components/Messages/LabelInputGroup.tsx | 84 +++++++ .../src/components/Messages/LabelPopup.tsx | 34 +-- .../components/Messages/LabelYesNoGroup.tsx | 89 ++++++++ .../components/Messages/MessageTableEntry.tsx | 2 +- ...belInputGroup.tsx => LabelLikertGroup.tsx} | 4 +- .../src/components/Survey/TaskControls.tsx | 46 ++-- website/src/components/Tasks/CreateTask.tsx | 6 +- website/src/components/Tasks/EvaluateTask.tsx | 18 +- .../components/Tasks/LabelTask/LabelTask.tsx | 109 +++++---- website/src/components/Tasks/Task/Task.tsx | 216 +++++++++++------- website/src/types/Tasks.ts | 34 +-- website/types/i18next.d.ts | 2 + 21 files changed, 504 insertions(+), 353 deletions(-) create mode 100644 website/public/locales/en/labelling.json delete mode 100644 website/src/components/FlaggableElement.tsx create mode 100644 website/src/components/Messages/LabelFlagGroup.tsx create mode 100644 website/src/components/Messages/LabelInputGroup.tsx create mode 100644 website/src/components/Messages/LabelYesNoGroup.tsx rename website/src/components/Survey/{LabelInputGroup.tsx => LabelLikertGroup.tsx} (98%) diff --git a/website/cypress/e2e/tasks/label_assistant_reply.cy.ts b/website/cypress/e2e/tasks/label_assistant_reply.cy.ts index 422db37c..18ab807f 100644 --- a/website/cypress/e2e/tasks/label_assistant_reply.cy.ts +++ b/website/cypress/e2e/tasks/label_assistant_reply.cy.ts @@ -11,6 +11,10 @@ describe("labeling assistant replies", () => { // For specific task pages the no task available result is normal. if (type === undefined) return; + cy.get('[data-cy="label-question"]').each((label) => { + // Click the no button, this generally approves the spam check + cy.wrap(label).find('[data-cy="no"]').click(); + }); cy.get('[data-cy="label-options"]').each((label) => { // Click the 4th option cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click(); diff --git a/website/cypress/e2e/tasks/label_initial_prompt.cy.ts b/website/cypress/e2e/tasks/label_initial_prompt.cy.ts index be1cf9bb..f11a068d 100644 --- a/website/cypress/e2e/tasks/label_initial_prompt.cy.ts +++ b/website/cypress/e2e/tasks/label_initial_prompt.cy.ts @@ -11,6 +11,10 @@ describe("labeling initial prompts", () => { // For specific task pages the no task available result is normal. if (type === undefined) return; + cy.get('[data-cy="label-question"]').each((label) => { + // Click the no button, this generally approves the spam check + cy.wrap(label).find('[data-cy="no"]').click(); + }); cy.get('[data-cy="label-options"]').each((label) => { // Click the 4th option cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click(); diff --git a/website/cypress/e2e/tasks/label_prompter_reply.cy.ts b/website/cypress/e2e/tasks/label_prompter_reply.cy.ts index a3c06cb3..23801b57 100644 --- a/website/cypress/e2e/tasks/label_prompter_reply.cy.ts +++ b/website/cypress/e2e/tasks/label_prompter_reply.cy.ts @@ -11,6 +11,10 @@ describe("labeling prompter replies", () => { // For specific task pages the no task available result is normal. if (type === undefined) return; + cy.get('[data-cy="label-question"]').each((label) => { + // Click the no button, this generally approves the spam check + cy.wrap(label).find('[data-cy="no"]').click(); + }); cy.get('[data-cy="label-options"]').each((label) => { // Click the 4th option cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click(); diff --git a/website/cypress/e2e/tasks/random.cy.ts b/website/cypress/e2e/tasks/random.cy.ts index 0074bc53..0ca3c7f5 100644 --- a/website/cypress/e2e/tasks/random.cy.ts +++ b/website/cypress/e2e/tasks/random.cy.ts @@ -44,6 +44,10 @@ describe("handles random tasks", () => { break; } case "label-task": { + cy.get('[data-cy="label-question"]').each((label) => { + // Click the no button, this generally approves the spam check + cy.wrap(label).find('[data-cy="no"]').click(); + }); cy.get('[data-cy="label-options"]').each((label) => { // Click the 4th option cy.wrap(label).find('[data-cy="radio-option"]').eq(3).click(); @@ -55,15 +59,6 @@ describe("handles random tasks", () => { break; } - case "spam-task": { - cy.get('[data-cy="not-spam-button"]').click(); - - cy.get('[data-cy="review"]').click(); - - cy.get('[data-cy="submit"]').click(); - - break; - } case undefined: { throw new Error("No tasks available, but at least create initial prompt expected"); } diff --git a/website/public/locales/en/common.json b/website/public/locales/en/common.json index 0b0f9d37..f8e31c99 100644 --- a/website/public/locales/en/common.json +++ b/website/public/locales/en/common.json @@ -16,5 +16,7 @@ "sign_in": "Sign In", "sign_out": "Sign Out", "terms_of_service": "Terms of Service", - "title": "Open Assistant" + "title": "Open Assistant", + "yes": "Yes", + "no": "No" } diff --git a/website/public/locales/en/labelling.json b/website/public/locales/en/labelling.json new file mode 100644 index 00000000..13582c98 --- /dev/null +++ b/website/public/locales/en/labelling.json @@ -0,0 +1,16 @@ +{ + "label_highlighted_yes_no_instruction": "Answer the following question(s) about the highlighted message:", + "label_highlighted_flag_instruction": "Select any that apply to the highlighted message:", + "label_highlighted_likert_instruction": "Rate the highlighted message:", + "label_message_yes_no_instruction": "Answer the following question(s) about the message:", + "label_message_flag_instruction": "Select any that apply to the message:", + "label_message_likert_instruction": "Rate the message:", + "spam.question": "Is the message spam?", + "fails_task.question": "Does the reply fail the propmpters task?", + "not_appropriate": "Not Appropriate", + "pii": "Contains PII", + "hate_speech": "Hate Speech", + "sexual_content": "Sexual Content", + "moral_judgement": "Judges Morality", + "political_content": "Politcal" +} diff --git a/website/src/components/FlaggableElement.tsx b/website/src/components/FlaggableElement.tsx deleted file mode 100644 index d7572080..00000000 --- a/website/src/components/FlaggableElement.tsx +++ /dev/null @@ -1,116 +0,0 @@ -import { - Box, - Button, - Modal, - ModalBody, - ModalCloseButton, - ModalContent, - ModalFooter, - ModalHeader, - ModalOverlay, - Popover, - PopoverAnchor, - PopoverTrigger, - Tooltip, - useColorModeValue, - useDisclosure, -} from "@chakra-ui/react"; -import { AlertCircle } from "lucide-react"; -import { useState } from "react"; -import { get, post } from "src/lib/api"; -import { colors } from "src/styles/Theme/colors"; -import { Message } from "src/types/Conversation"; -import useSWRImmutable from "swr/immutable"; -import useSWRMutation from "swr/mutation"; - -import { LabelInputGroup } from "./Survey/LabelInputGroup"; - -interface Label { - name: string; - display_text: string; - help_text: string; -} - -interface FlaggableElementProps { - children: React.ReactNode; - message: Message; -} - -interface ValidLabelsResponse { - valid_labels: Label[]; -} - -export const FlaggableElement = (props: FlaggableElementProps) => { - const { data: response } = useSWRImmutable("/api/valid_labels", get); - const { isOpen, onOpen, onClose } = useDisclosure(); - const { valid_labels } = response || { valid_labels: [] }; - const [values, setValues] = useState([]); - - const submittable = - values.some((value) => { - return value !== null; - }) && - values.length === valid_labels.length && - valid_labels.length > 0; - - const { trigger } = useSWRMutation("/api/set_label", post, { - onSuccess: onClose, - onError: onClose, - }); - - const submitResponse = () => { - const label_map: Map = new Map(); - console.assert(valid_labels.length === values.length); - values.forEach((value, idx) => { - if (value !== null) { - label_map.set(valid_labels[idx].name, value); - } - }); - trigger({ - message_id: props.message.id, - label_map: Object.fromEntries(label_map), - text: props.message.text, - }); - }; - - return ( - - - {props.children} - - - - - - - - - - - - - - - Select one or more labels that apply. - - - name)} onChange={setValues} /> - - - - - - - - ); -}; diff --git a/website/src/components/Messages.tsx b/website/src/components/Messages.tsx index c9d77e3c..58e0d2be 100644 --- a/website/src/components/Messages.tsx +++ b/website/src/components/Messages.tsx @@ -1,25 +1,7 @@ -import { Box, forwardRef, Grid, useColorMode } from "@chakra-ui/react"; +import { Box, forwardRef, useColorMode } from "@chakra-ui/react"; import { useMemo } from "react"; import { Message } from "src/types/Conversation"; -import { FlaggableElement } from "./FlaggableElement"; - -interface MessagesProps { - messages: Message[]; -} - -export const Messages = ({ messages }: MessagesProps) => { - const items = messages.map((messageProps: Message, i: number) => { - return ( - - - - ); - }); - // Maybe also show a legend of the colors? - return {items}; -}; - export const MessageView = forwardRef, "div">((message: Partial, ref) => { const { colorMode } = useColorMode(); diff --git a/website/src/components/Messages/LabelFlagGroup.tsx b/website/src/components/Messages/LabelFlagGroup.tsx new file mode 100644 index 00000000..fb1158bc --- /dev/null +++ b/website/src/components/Messages/LabelFlagGroup.tsx @@ -0,0 +1,32 @@ +import { Button, Flex } from "@chakra-ui/react"; +import { useTranslation } from "next-i18next"; +import { getTypeSafei18nKey } from "src/lib/i18n"; + +interface LabelFlagGroupProps { + values: number[]; + labelNames: string[]; + isEditable?: boolean; + onChange: (values: number[]) => void; +} + +export const LabelFlagGroup = ({ values, labelNames, isEditable = true, onChange }: LabelFlagGroupProps) => { + const { t } = useTranslation("labelling"); + return ( + + {labelNames.map((name, idx) => ( + + ))} + + ); +}; diff --git a/website/src/components/Messages/LabelInputGroup.tsx b/website/src/components/Messages/LabelInputGroup.tsx new file mode 100644 index 00000000..51383128 --- /dev/null +++ b/website/src/components/Messages/LabelInputGroup.tsx @@ -0,0 +1,84 @@ +import { Text, VStack } from "@chakra-ui/react"; +import { Label } from "src/types/Tasks"; + +import { LabelLikertGroup } from "../Survey/LabelLikertGroup"; +import { LabelFlagGroup } from "./LabelFlagGroup"; +import { LabelYesNoGroup } from "./LabelYesNoGroup"; + +export interface LabelInputInstructions { + yesNoInstruction: string; + flagInstruction: string; + likertInstruction: string; +} + +interface LabelInputGroupProps { + values: number[]; + labels: Label[]; + requiredLabels?: string[]; + isEditable?: boolean; + instructions: LabelInputInstructions; + onChange: (values: number[]) => void; +} + +export const LabelInputGroup = ({ + labels, + values, + requiredLabels, + isEditable, + instructions, + onChange, +}: LabelInputGroupProps) => { + const yesNoIndexes = labels.map((label, idx) => (label.widget === "yes_no" ? idx : null)).filter((v) => v !== null); + const flagIndexes = labels.map((label, idx) => (label.widget === "flag" ? idx : null)).filter((v) => v !== null); + const likertIndexes = labels.map((label, idx) => (label.widget === "likert" ? idx : null)).filter((v) => v !== null); + + return ( + + {yesNoIndexes.length > 0 && ( + + {instructions.yesNoInstruction} + values[idx])} + labelNames={yesNoIndexes.map((idx) => labels[idx].name)} + isEditable={isEditable} + requiredLabels={requiredLabels} + onChange={(yesNoValues) => { + const newValues = values.slice(); + yesNoIndexes.forEach((idx, yesNoIndex) => (newValues[idx] = yesNoValues[yesNoIndex])); + onChange(newValues); + }} + /> + + )} + {flagIndexes.length > 0 && ( + + {instructions.flagInstruction} + values[idx])} + labelNames={flagIndexes.map((idx) => labels[idx].name)} + isEditable={isEditable} + onChange={(flagValues) => { + const newValues = values.slice(); + flagIndexes.forEach((idx, flagIndex) => (newValues[idx] = flagValues[flagIndex])); + onChange(newValues); + }} + /> + + )} + {likertIndexes.length > 0 && ( + + {instructions.likertInstruction} + labels[idx].name)} + isEditable={isEditable} + onChange={(likertValues) => { + const newValues = values.slice(); + likertIndexes.forEach((idx, likertIndex) => (newValues[idx] = likertValues[likertIndex])); + onChange(newValues); + }} + /> + + )} + + ); +}; diff --git a/website/src/components/Messages/LabelPopup.tsx b/website/src/components/Messages/LabelPopup.tsx index b2b95278..ac564e6d 100644 --- a/website/src/components/Messages/LabelPopup.tsx +++ b/website/src/components/Messages/LabelPopup.tsx @@ -9,9 +9,10 @@ import { ModalOverlay, } from "@chakra-ui/react"; import { useTranslation } from "next-i18next"; -import { useState } from "react"; -import { LabelInputGroup } from "src/components/Survey/LabelInputGroup"; +import { useEffect, useState } from "react"; +import { LabelInputGroup } from "src/components/Messages/LabelInputGroup"; import { get, post } from "src/lib/api"; +import { Label } from "src/types/Tasks"; import useSWRImmutable from "swr/immutable"; import useSWRMutation from "swr/mutation"; @@ -21,21 +22,19 @@ interface LabelMessagePopupProps { onClose: () => void; } -interface Label { - name: string; - display_text: string; - help_text: string; -} - interface ValidLabelsResponse { valid_labels: Label[]; } export const LabelMessagePopup = ({ messageId, show, onClose }: LabelMessagePopupProps) => { - const { t } = useTranslation("message"); + const { t } = useTranslation(); const { data: response } = useSWRImmutable("/api/valid_labels", get); const valid_labels = response?.valid_labels ?? []; - const [values, setValues] = useState(null); + const [values, setValues] = useState(new Array(valid_labels.length).fill(null)); + + useEffect(() => { + setValues(new Array(valid_labels.length).fill(null)); + }, [messageId, valid_labels.length]); const { trigger: setLabels } = useSWRMutation("/api/set_label", post); @@ -60,14 +59,23 @@ export const LabelMessagePopup = ({ messageId, show, onClose }: LabelMessagePopu - {t("label_title")} + {t("message:label_title")} - name)} onChange={setValues} /> + diff --git a/website/src/components/Messages/LabelYesNoGroup.tsx b/website/src/components/Messages/LabelYesNoGroup.tsx new file mode 100644 index 00000000..72c40e2b --- /dev/null +++ b/website/src/components/Messages/LabelYesNoGroup.tsx @@ -0,0 +1,89 @@ +import { Button, HStack, Text, Tooltip } from "@chakra-ui/react"; +import { useTranslation } from "next-i18next"; +import { getTypeSafei18nKey } from "src/lib/i18n"; + +interface LabelYesNoGroupProps { + values: number[]; + labelNames: string[]; + requiredLabels?: string[]; + isEditable?: boolean; + onChange: (values: number[]) => void; +} + +export const LabelYesNoGroup = ({ + values, + labelNames, + requiredLabels = [], + isEditable = true, + onChange, +}: LabelYesNoGroupProps) => { + const { t } = useTranslation("labelling"); + return ( + <> + {labelNames.map((name, idx) => { + return ( + 0.1 ? true : false} + onChange={(value) => { + const newValues = values.slice(); + newValues[idx] = value; + onChange(newValues); + }} + isEditable={isEditable} + isRequired={requiredLabels.includes(name)} + /> + ); + })} + + ); +}; + +const YesNoQuestion = ({ + isEditable, + question, + value, + isRequired, + onChange, +}: { + isEditable: boolean; + question: string; + value: boolean; + isRequired?: boolean; + onChange: (boolean) => void; +}) => { + const { t } = useTranslation(); + return ( +
+ + {question} + {isRequired ? : undefined} + + + + + +
+ ); +}; + +const RequiredMark = () => ( + + * + +); diff --git a/website/src/components/Messages/MessageTableEntry.tsx b/website/src/components/Messages/MessageTableEntry.tsx index 7202903a..2673ad49 100644 --- a/website/src/components/Messages/MessageTableEntry.tsx +++ b/website/src/components/Messages/MessageTableEntry.tsx @@ -36,7 +36,7 @@ export function MessageTableEntry({ message, enabled, highlight }: MessageTableE const router = useRouter(); const [emojiState, setEmojis] = useState({ emojis: {}, user_emojis: [] }); useEffect(() => { - setEmojis({ emojis: message.emojis, user_emojis: message.user_emojis }); + setEmojis({ emojis: message.emojis || {}, user_emojis: message.user_emojis || [] }); }, [message.emojis, message.user_emojis]); const goToMessage = useCallback(() => router.push(`/messages/${message.id}`), [router, message.id]); diff --git a/website/src/components/Survey/LabelInputGroup.tsx b/website/src/components/Survey/LabelLikertGroup.tsx similarity index 98% rename from website/src/components/Survey/LabelInputGroup.tsx rename to website/src/components/Survey/LabelLikertGroup.tsx index 94fcc48e..fc959a26 100644 --- a/website/src/components/Survey/LabelInputGroup.tsx +++ b/website/src/components/Survey/LabelLikertGroup.tsx @@ -135,7 +135,7 @@ const getLabelInfo = (label: string): LabelInfo => { oneDescription: ["Contains text which is incorrect or misleading"], inverted: true, }; - case "helpful": + case "helpfulness": return { zeroText: "Unhelful", zeroDescription: [], @@ -186,7 +186,7 @@ const getLabelInfo = (label: string): LabelInfo => { } }; -export const LabelInputGroup = ({ labelIDs, onChange, isEditable = true }: LabelInputGroupProps) => { +export const LabelLikertGroup = ({ labelIDs, onChange, isEditable = true }: LabelInputGroupProps) => { const [labelValues, setLabelValues] = useState(Array.from({ length: labelIDs.length }).map(() => null)); const cardColor = useColorModeValue("gray.50", "gray.800"); diff --git a/website/src/components/Survey/TaskControls.tsx b/website/src/components/Survey/TaskControls.tsx index 76aaee8b..a3c3ffdd 100644 --- a/website/src/components/Survey/TaskControls.tsx +++ b/website/src/components/Survey/TaskControls.tsx @@ -4,12 +4,10 @@ import { SkipButton } from "src/components/Buttons/Skip"; import { SubmitButton } from "src/components/Buttons/Submit"; import { TaskInfo } from "src/components/TaskInfo/TaskInfo"; import { TaskStatus } from "src/components/Tasks/Task"; +import { BaseTask } from "src/types/Task"; export interface TaskControlsProps { - // we need a task type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - task: any; - className?: string; + task: BaseTask; taskStatus: TaskStatus; onEdit: () => void; onReview: () => void; @@ -17,7 +15,7 @@ export interface TaskControlsProps { onSkip: (reason: string) => void; } -export const TaskControls = (props: TaskControlsProps) => { +export const TaskControls = ({ task, taskStatus, onEdit, onReview, onSubmit, onSkip }: TaskControlsProps) => { const backgroundColor = useColorModeValue("white", "gray.800"); return ( @@ -31,38 +29,32 @@ export const TaskControls = (props: TaskControlsProps) => { shadow="base" gap="4" > - + - {props.taskStatus === "REVIEW" || props.taskStatus === "SUBMITTED" ? ( + {taskStatus.mode === "EDIT" ? ( <> - - } - /> - + - Submit + Review ) : ( <> - + + } /> + - Review + Submit )} diff --git a/website/src/components/Tasks/CreateTask.tsx b/website/src/components/Tasks/CreateTask.tsx index 36493e27..5276a183 100644 --- a/website/src/components/Tasks/CreateTask.tsx +++ b/website/src/components/Tasks/CreateTask.tsx @@ -7,6 +7,8 @@ import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards"; import { TaskSurveyProps } from "src/components/Tasks/Task"; import { TaskHeader } from "src/components/Tasks/TaskHeader"; import { getTypeSafei18nKey } from "src/lib/i18n"; +import { TaskType } from "src/types/Task"; +import { CreateAssistantReplyTask, CreateInitialPromptTask, CreatePrompterReplyTask } from "src/types/Tasks"; export const CreateTask = ({ task, @@ -15,7 +17,7 @@ export const CreateTask = ({ isDisabled, onReplyChanged, onValidityChanged, -}: TaskSurveyProps<{ text: string }>) => { +}: TaskSurveyProps) => { const { t, i18n } = useTranslation(["tasks", "common"]); const cardColor = useColorModeValue("gray.50", "gray.800"); const titleColor = useColorModeValue("gray.800", "gray.300"); @@ -39,7 +41,7 @@ export const CreateTask = ({ <> - {!!task.conversation && ( + {task.type !== TaskType.initial_prompt && ( diff --git a/website/src/components/Tasks/EvaluateTask.tsx b/website/src/components/Tasks/EvaluateTask.tsx index 4be86dd0..b554ffd4 100644 --- a/website/src/components/Tasks/EvaluateTask.tsx +++ b/website/src/components/Tasks/EvaluateTask.tsx @@ -5,6 +5,8 @@ import { Sortable } from "src/components/Sortable/Sortable"; import { SurveyCard } from "src/components/Survey/SurveyCard"; import { TaskSurveyProps } from "src/components/Tasks/Task"; import { TaskHeader } from "src/components/Tasks/TaskHeader"; +import { TaskType } from "src/types/Task"; +import { RankAssistantRepliesTask, RankInitialPromptsTask, RankPrompterRepliesTask } from "src/types/Tasks"; export const EvaluateTask = ({ task, @@ -13,19 +15,25 @@ export const EvaluateTask = ({ isDisabled, onReplyChanged, onValidityChanged, -}: TaskSurveyProps<{ ranking: number[] }>) => { +}: TaskSurveyProps< + RankInitialPromptsTask | RankAssistantRepliesTask | RankPrompterRepliesTask, + { ranking: number[] } +>) => { const cardColor = useColorModeValue("gray.50", "gray.800"); const [ranking, setRanking] = useState(null); let messages = []; - if (task.conversation) { + if (task.type !== TaskType.rank_initial_prompts) { messages = task.conversation.messages; } useEffect(() => { if (ranking === null) { - const defaultRanking = (task.replies ?? task.prompts).map((_, idx) => idx); - onReplyChanged({ ranking: defaultRanking }); + if (task.type === TaskType.rank_initial_prompts) { + onReplyChanged({ ranking: task.prompts.map((_, idx) => idx) }); + } else { + onReplyChanged({ ranking: task.replies.map((_, idx) => idx) }); + } onValidityChanged("DEFAULT"); } else { onReplyChanged({ ranking }); @@ -33,7 +41,7 @@ export const EvaluateTask = ({ } }, [task, ranking, onReplyChanged, onValidityChanged]); - const sortables = task.replies ? "replies" : "prompts"; + const sortables = task.type === TaskType.rank_initial_prompts ? "prompts" : "replies"; return (
diff --git a/website/src/components/Tasks/LabelTask/LabelTask.tsx b/website/src/components/Tasks/LabelTask/LabelTask.tsx index 10ea76fb..33152ba1 100644 --- a/website/src/components/Tasks/LabelTask/LabelTask.tsx +++ b/website/src/components/Tasks/LabelTask/LabelTask.tsx @@ -1,11 +1,18 @@ -import { Box, Button, Flex, HStack, Text, useColorModeValue } from "@chakra-ui/react"; +import { Box, useBoolean, useColorModeValue } from "@chakra-ui/react"; +import { useTranslation } from "next-i18next"; import { useEffect, useState } from "react"; import { MessageView } from "src/components/Messages"; +import { LabelInputGroup } from "src/components/Messages/LabelInputGroup"; import { MessageTable } from "src/components/Messages/MessageTable"; -import { LabelInputGroup } from "src/components/Survey/LabelInputGroup"; import { TwoColumnsWithCards } from "src/components/Survey/TwoColumnsWithCards"; import { TaskSurveyProps } from "src/components/Tasks/Task"; import { TaskHeader } from "src/components/Tasks/TaskHeader"; +import { TaskType } from "src/types/Task"; +import { LabelTaskType } from "src/types/Tasks"; + +const isRequired = (labelName: string, requiredLabels?: string[]) => { + return requiredLabels ? requiredLabels.includes(labelName) : false; +}; export const LabelTask = ({ task, @@ -13,15 +20,33 @@ export const LabelTask = ({ isEditable, onReplyChanged, onValidityChanged, -}: TaskSurveyProps<{ text: string; labels: Record; message_id: string }>) => { - const [sliderValues, setSliderValues] = useState(new Array(task.valid_labels.length).fill(null)); +}: TaskSurveyProps; message_id: string }>) => { + const { t } = useTranslation("labelling"); + const [values, setValues] = useState(new Array(task.labels.length).fill(null)); + const [userInputMade, setUserInputMade] = useBoolean(false); + // Initial setup to run when the task changes useEffect(() => { - console.assert(task.valid_labels.length === sliderValues.length); - const labels = Object.fromEntries(task.valid_labels.map((label, i) => [label, sliderValues[i]])); - onReplyChanged({ labels, text: task.reply || task.prompt, message_id: task.message_id }); - onValidityChanged(sliderValues.every((value) => value !== null) ? "VALID" : "INVALID"); - }, [task, sliderValues, onReplyChanged, onValidityChanged]); + setValues(new Array(task.labels.length).fill(null)); + onValidityChanged(task.labels.some(({ name }) => isRequired(name, task.mandatory_labels)) ? "INVALID" : "DEFAULT"); + setUserInputMade.off(); + }, [task, setUserInputMade, onValidityChanged]); + + // Update the reply and validity when the values change + useEffect(() => { + onReplyChanged({ + text: "unused?", + labels: Object.fromEntries(task.labels.map(({ name }, idx) => [name, values[idx] || 0])), + message_id: task.message_id, + }); + onValidityChanged( + task.labels.some(({ name }, idx) => values[idx] === null && isRequired(name, task.mandatory_labels)) + ? "INVALID" + : userInputMade + ? "VALID" + : "DEFAULT" + ); + }, [task, values, onReplyChanged, userInputMade, onValidityChanged]); const cardColor = useColorModeValue("gray.50", "gray.800"); const isSpamTask = task.mode === "simple" && task.valid_labels.length === 1 && task.valid_labels[0] === "spam"; @@ -31,12 +56,9 @@ export const LabelTask = ({ <> - {task.conversation ? ( + {task.type !== TaskType.label_initial_prompt ? ( - + ) : ( @@ -44,51 +66,22 @@ export const LabelTask = ({ )} - {isSpamTask ? ( - setSliderValues([value])} - isEditable={isEditable} - /> - ) : ( - - The highlighted message: - - - )} + { + setValues(values); + setUserInputMade.on(); + }} + />
); }; - -const SpamTaskInput = ({ - isEditable, - value, - onChange, -}: { - isEditable: boolean; - value: number; - onChange: (number) => void; -}) => { - return ( - - Is the highlighted message spam? - - - - ); -}; diff --git a/website/src/components/Tasks/Task/Task.tsx b/website/src/components/Tasks/Task/Task.tsx index ae82ef97..51ba6fa3 100644 --- a/website/src/components/Tasks/Task/Task.tsx +++ b/website/src/components/Tasks/Task/Task.tsx @@ -1,5 +1,6 @@ import { useTranslation } from "next-i18next"; -import { useRef, useState } from "react"; +import { useCallback, useEffect, useReducer } from "react"; +import { useMemo, useRef } from "react"; import { TaskControls } from "src/components/Survey/TaskControls"; import { CreateTask } from "src/components/Tasks/CreateTask"; import { EvaluateTask } from "src/components/Tasks/EvaluateTask"; @@ -8,15 +9,52 @@ import { TaskCategory, TaskInfo, TaskInfos } from "src/components/Tasks/TaskType import { UnchangedWarning } from "src/components/Tasks/UnchangedWarning"; import { post } from "src/lib/api"; import { getTypeSafei18nKey } from "src/lib/i18n"; -import { TaskContent, TaskReplyValidity } from "src/types/Task"; +import { BaseTask, TaskContent, TaskReplyValidity } from "src/types/Task"; import useSWRMutation from "swr/mutation"; -export type TaskStatus = "NOT_SUBMITTABLE" | "DEFAULT" | "VALID" | "REVIEW" | "SUBMITTED"; +interface EditMode { + mode: "EDIT"; + replyValidity: TaskReplyValidity; +} +interface ReviewMode { + mode: "REVIEW"; +} +interface DefaultWarnMode { + mode: "DEFAULT_WARN"; +} +interface SubmittedMode { + mode: "SUBMITTED"; +} -export interface TaskSurveyProps { - // we need a task type - // eslint-disable-next-line @typescript-eslint/no-explicit-any - task: any; +export type TaskStatus = EditMode | DefaultWarnMode | ReviewMode | SubmittedMode; + +interface NewTask { + action: "NEW_TASK"; +} + +interface Review { + action: "REVIEW"; +} + +interface SetSubmitted { + action: "SET_SUBMITTED"; +} + +interface ReturnToEdit { + action: "RETURN_EDIT"; +} + +interface AcceptDefault { + action: "ACCEPT_DEFAULT"; +} + +interface UpdateValidity { + action: "UPDATE_VALIDITY"; + replyValidity: TaskReplyValidity; +} + +export interface TaskSurveyProps { + task: TaskType; taskType: TaskInfo; isEditable: boolean; isDisabled?: boolean; @@ -26,13 +64,63 @@ export interface TaskSurveyProps { export const Task = ({ frontendId, task, trigger, mutate }) => { const { t } = useTranslation("tasks"); - const [taskStatus, setTaskStatus] = useState("NOT_SUBMITTABLE"); + const [taskStatus, taskEvent] = useReducer( + ( + status: TaskStatus, + event: NewTask | UpdateValidity | AcceptDefault | Review | ReturnToEdit | SetSubmitted + ): TaskStatus => { + switch (event.action) { + case "NEW_TASK": + return { mode: "EDIT", replyValidity: "INVALID" }; + case "UPDATE_VALIDITY": + return status.mode === "EDIT" ? { mode: "EDIT", replyValidity: event.replyValidity } : status; + case "ACCEPT_DEFAULT": + return status.mode === "DEFAULT_WARN" ? { mode: "REVIEW" } : status; + case "REVIEW": { + if (status.mode === "EDIT") { + switch (status.replyValidity) { + case "DEFAULT": + return { mode: "DEFAULT_WARN" }; + case "VALID": + return { mode: "REVIEW" }; + } + } + return status; + } + case "RETURN_EDIT": { + switch (status.mode) { + case "REVIEW": + return { mode: "EDIT", replyValidity: "VALID" }; + case "DEFAULT_WARN": + return { mode: "EDIT", replyValidity: "DEFAULT" }; + default: + return status; + } + } + case "SET_SUBMITTED": { + return status.mode === "REVIEW" ? { mode: "SUBMITTED" } : status; + } + } + }, + { mode: "EDIT", replyValidity: "INVALID" } + ); + const replyContent = useRef(null); - const [showUnchangedWarning, setShowUnchangedWarning] = useState(false); + const updateValidity = useCallback( + (replyValidity: TaskReplyValidity) => taskEvent({ action: "UPDATE_VALIDITY", replyValidity }), + [taskEvent] + ); + + useEffect(() => { + taskEvent({ action: "NEW_TASK" }); + }, [task.id, updateValidity]); const rootEl = useRef(null); - const taskType = TaskInfos.find((taskType) => taskType.type === task.type && taskType.mode === task.mode); + const taskType = useMemo( + () => TaskInfos.find((taskType) => taskType.type === task.type && taskType.mode === task.mode), + [task.type, task.mode] + ); const { trigger: sendRejection } = useSWRMutation("/api/reject_task", post, { onSuccess: async () => { @@ -47,79 +135,36 @@ export const Task = ({ frontendId, task, trigger, mutate }) => { }); }; - const edit_mode = taskStatus === "NOT_SUBMITTABLE" || taskStatus === "DEFAULT" || taskStatus === "VALID"; - const submitted = taskStatus === "SUBMITTED"; - - const onValidityChanged = (validity: TaskReplyValidity) => { - if (!edit_mode) return; - switch (validity) { - case "DEFAULT": - if (taskStatus !== "DEFAULT") setTaskStatus("DEFAULT"); - break; - case "VALID": - if (taskStatus !== "VALID") setTaskStatus("VALID"); - break; - case "INVALID": - if (taskStatus !== "NOT_SUBMITTABLE") setTaskStatus("NOT_SUBMITTABLE"); - break; - } - }; - - const onReplyChanged = (content: TaskContent) => { - replyContent.current = content; - }; - - const reviewResponse = () => { - switch (taskStatus) { - case "DEFAULT": - setShowUnchangedWarning(true); - break; - case "VALID": - setTaskStatus("REVIEW"); - break; - default: - return; - } - }; - - const editResponse = () => { - switch (taskStatus) { - case "REVIEW": - setTaskStatus("VALID"); - break; - default: - return; - } - }; + const onReplyChanged = useCallback( + (content: TaskContent) => { + replyContent.current = content; + }, + [replyContent] + ); const submitResponse = () => { - switch (taskStatus) { - case "REVIEW": { - trigger({ - id: frontendId, - update_type: taskType.update_type, - content: replyContent.current, - }); - setTaskStatus("SUBMITTED"); - scrollToTop(rootEl.current); - break; - } - default: - return; + if (taskStatus.mode === "REVIEW") { + trigger({ + id: frontendId, + update_type: taskType.update_type, + content: replyContent.current, + }); + taskEvent({ action: "SET_SUBMITTED" }); + scrollToTop(rootEl.current); } }; - function taskTypeComponent() { + const taskTypeComponent = useMemo(() => { switch (taskType.category) { case TaskCategory.Create: return ( ); case TaskCategory.Evaluate: @@ -127,10 +172,10 @@ export const Task = ({ frontendId, task, trigger, mutate }) => { ); case TaskCategory.Label: @@ -138,37 +183,34 @@ export const Task = ({ frontendId, task, trigger, mutate }) => { ); } - } + }, [task, taskType, taskStatus.mode, onReplyChanged, updateValidity]); return (
- {taskTypeComponent()} + {taskTypeComponent} taskEvent({ action: "RETURN_EDIT" })} + onReview={() => taskEvent({ action: "REVIEW" })} onSubmit={submitResponse} onSkip={rejectTask} /> setShowUnchangedWarning(false)} + onClose={() => taskEvent({ action: "RETURN_EDIT" })} onContinueAnyway={() => { - if (taskStatus === "DEFAULT") { - setTaskStatus("REVIEW"); - setShowUnchangedWarning(false); - } + taskEvent({ action: "ACCEPT_DEFAULT" }); }} />
diff --git a/website/src/types/Tasks.ts b/website/src/types/Tasks.ts index bbbe3a67..5fbc84c7 100644 --- a/website/src/types/Tasks.ts +++ b/website/src/types/Tasks.ts @@ -33,31 +33,39 @@ export interface RankPrompterRepliesTask extends BaseTask { replies: string[]; } -export interface LabelAssistantReplyTask extends BaseTask { +export interface Label { + display_text: string; + help_text: string; + name: string; + widget: "flag" | "yes_no" | "likert"; +} + +export interface BaseLabelTask extends BaseTask { + message_id: string; + labels: Label[]; + valid_labels: string[]; + disposition: "spam" | "quality"; + mode: "simple" | "full"; + mandatory_labels?: string[]; +} + +export interface LabelAssistantReplyTask extends BaseLabelTask { type: TaskType.label_assistant_reply; - message_id: string; conversation: Conversation; reply_message: Message; reply: string; - valid_labels: string[]; - mode: "simple" | "full"; - mandatory_labels?: string[]; } -export interface LabelPrompterReplyTask extends BaseTask { +export interface LabelPrompterReplyTask extends BaseLabelTask { type: TaskType.label_prompter_reply; - message_id: string; conversation: Conversation; reply_message: Message; reply: string; - valid_labels: string[]; - mode: "simple" | "full"; - mandatory_labels?: string[]; } -export interface LabelInitialPromptTask extends BaseTask { +export interface LabelInitialPromptTask extends BaseLabelTask { type: TaskType.label_initial_prompt; - message_id: string; - valid_labels: string[]; prompt: string; } + +export type LabelTaskType = LabelAssistantReplyTask | LabelPrompterReplyTask | LabelInitialPromptTask; diff --git a/website/types/i18next.d.ts b/website/types/i18next.d.ts index a00b1a80..0a2cf10a 100644 --- a/website/types/i18next.d.ts +++ b/website/types/i18next.d.ts @@ -3,6 +3,7 @@ import type dashboard from "public/locales/en/dashboard.json"; import type index from "public/locales/en/index.json"; import type leaderboard from "public/locales/en/leaderboard.json"; import type message from "public/locales/en/message.json"; +import type labelling from "public/locales/en/labelling.json"; import type tasks from "public/locales/en/tasks.json"; declare module "i18next" { @@ -14,6 +15,7 @@ declare module "i18next" { leaderboard: typeof leaderboard; tasks: typeof tasks; message: typeof message; + labelling: typeof labelling; }; } } From 54503b7e1b833fd184af100c3986e7709895b78e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Sat, 28 Jan 2023 19:10:03 +0100 Subject: [PATCH 06/13] reduce userstats cron defaults, fix reference error --- backend/oasst_backend/config.py | 6 +++--- backend/oasst_backend/prompt_repository.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 5c566b1e..12cd3e89 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -179,9 +179,9 @@ class Settings(BaseSettings): tree_manager: Optional[TreeManagerConfiguration] = TreeManagerConfiguration() - USER_STATS_INTERVAL_DAY: int = 15 # minutes - USER_STATS_INTERVAL_WEEK: int = 60 # minutes - USER_STATS_INTERVAL_MONTH: int = 120 # minutes + USER_STATS_INTERVAL_DAY: int = 5 # minutes + USER_STATS_INTERVAL_WEEK: int = 15 # minutes + USER_STATS_INTERVAL_MONTH: int = 60 # minutes USER_STATS_INTERVAL_TOTAL: int = 240 # minutes @validator( diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index c69f7340..50808285 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -481,6 +481,7 @@ class PromptRepository: task_id=task.id if task else None, ) + message: Message = None if message_id: if not task: if text_labels.is_report is True: From 574737348047b7ecad7d0e085a23e38a6b3aae63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Sat, 28 Jan 2023 19:57:44 +0100 Subject: [PATCH 07/13] fix typo Unhelful --- website/src/components/Survey/LabelLikertGroup.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/src/components/Survey/LabelLikertGroup.tsx b/website/src/components/Survey/LabelLikertGroup.tsx index fc959a26..f7f14e7d 100644 --- a/website/src/components/Survey/LabelLikertGroup.tsx +++ b/website/src/components/Survey/LabelLikertGroup.tsx @@ -137,7 +137,7 @@ const getLabelInfo = (label: string): LabelInfo => { }; case "helpfulness": return { - zeroText: "Unhelful", + zeroText: "Unhelpful", zeroDescription: [], oneText: "Helpful", oneDescription: ["Completes the task to a high standard"], From eb4c41e3c6e89b15ad8205a99821f6e9d61e0714 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Sat, 28 Jan 2023 20:23:39 +0100 Subject: [PATCH 08/13] send full conversation (including last-message) in label tasks --- backend/oasst_backend/tree_manager.py | 2 +- oasst-shared/oasst_shared/schemas/protocol.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index bcab02c3..1f0f1784 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -361,7 +361,7 @@ class TreeManager: random_reply_message = random.choice(replies_need_review) messages = self.pr.fetch_message_conversation(random_reply_message) - conversation = prepare_conversation(messages[:-1]) + conversation = prepare_conversation(messages) message = messages[-1] self.cfg.p_full_labeling_review_reply_prompter: float = 0.1 diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index cbb4d29c..4cdea856 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -258,7 +258,7 @@ class LabelConversationReplyTask(AbstractLabelTask): """A task to label a reply to a conversation.""" type: Literal["label_conversation_reply"] = "label_conversation_reply" - conversation: Conversation # the conversation so far + conversation: Conversation # the conversation so far (new: including the reply message) reply_message: Optional[ConversationMessage] reply: str From 86911b445383eee974f5da5b7d972acb6d569590 Mon Sep 17 00:00:00 2001 From: rjmacarthy Date: Sat, 28 Jan 2023 20:17:23 +0000 Subject: [PATCH 09/13] Fix locale issue on messages/id page Pre-commit --- website/public/locales/en/common.json | 4 ++-- website/public/locales/en/message.json | 10 ++++++---- website/src/pages/messages/[id]/index.tsx | 8 +++++--- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/website/public/locales/en/common.json b/website/public/locales/en/common.json index f8e31c99..6298a32c 100644 --- a/website/public/locales/en/common.json +++ b/website/public/locales/en/common.json @@ -11,12 +11,12 @@ "legal": "Legal", "loading": "Loading...", "more_information": "More Information", + "no": "No", "privacy_policy": "Privacy Policy", "report_a_bug": "Report a Bug", "sign_in": "Sign In", "sign_out": "Sign Out", "terms_of_service": "Terms of Service", "title": "Open Assistant", - "yes": "Yes", - "no": "No" + "yes": "Yes" } diff --git a/website/public/locales/en/message.json b/website/public/locales/en/message.json index 45ea04a1..e16e3c0a 100644 --- a/website/public/locales/en/message.json +++ b/website/public/locales/en/message.json @@ -1,11 +1,13 @@ { - "reactions": "Reactions", "label_action": "Label", "label_title": "Label", - "submit_labels": "Submit", + "message": "Message", "open_new_tab_action": "Open in new tab", - "report_title": "Report", + "parent": "Parent", + "reactions": "Reactions", "report_action": "Report", "report_placeholder": "Why should this message be reviewed?", - "send_report": "Send" + "report_title": "Report", + "send_report": "Send", + "submit_labels": "Submit" } diff --git a/website/src/pages/messages/[id]/index.tsx b/website/src/pages/messages/[id]/index.tsx index 4513fc37..158d28e8 100644 --- a/website/src/pages/messages/[id]/index.tsx +++ b/website/src/pages/messages/[id]/index.tsx @@ -1,5 +1,6 @@ import { Box, Text, useColorModeValue } from "@chakra-ui/react"; import Head from "next/head"; +import { useTranslation } from "next-i18next"; import { serverSideTranslations } from "next-i18next/serverSideTranslations"; import { getDashboardLayout } from "src/components/Layout"; import { MessageLoading } from "src/components/Loading/MessageLoading"; @@ -10,6 +11,7 @@ import { Message } from "src/types/Conversation"; import useSWRImmutable from "swr/immutable"; const MessageDetail = ({ id }: { id: string }) => { + const { t } = useTranslation(["message", "common"]); const backgroundColor = useColorModeValue("white", "gray.800"); const { isLoading: isLoadingParent, data: parent } = useSWRImmutable(`/api/messages/${id}/parent`, get); @@ -20,7 +22,7 @@ const MessageDetail = ({ id }: { id: string }) => { return ( <> - Open Assistant + {t("common:title")} { <> - Parent + {t("parent")} @@ -54,7 +56,7 @@ MessageDetail.getLayout = (page) => getDashboardLayout(page); export const getServerSideProps = async ({ locale, query }) => ({ props: { id: query.id, - ...(await serverSideTranslations(locale, ["common"])), + ...(await serverSideTranslations(locale, ["common", "message"])), }, }); From f885e279618396cb9cd17ca1a75f7cbcad3ca172 Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Sun, 29 Jan 2023 00:37:28 +0100 Subject: [PATCH 10/13] added horizontal layout to inference dev setup --- inference/full-dev-setup.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/inference/full-dev-setup.sh b/inference/full-dev-setup.sh index 98a5b173..2251c62b 100755 --- a/inference/full-dev-setup.sh +++ b/inference/full-dev-setup.sh @@ -16,4 +16,5 @@ tmux split-window -h tmux send-keys "cd text-client" C-m tmux send-keys "sleep 5" C-m tmux send-keys "python __main__.py" C-m +tmux select-layout even-horizontal tmux attach-session -t "inference-dev-setup" From ccf96fd843f01d04b20c54da74b80dc97bc6c35d Mon Sep 17 00:00:00 2001 From: Alan Jean Date: Sun, 29 Jan 2023 06:28:00 +0400 Subject: [PATCH 11/13] Refactor tasks translation file to have one placeholder per task-type Fixes #979 --- website/public/locales/en/tasks.json | 10 ++++++---- website/src/components/Tasks/CreateTask.tsx | 6 +++++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/website/public/locales/en/tasks.json b/website/public/locales/en/tasks.json index 553a26d2..53cfa088 100644 --- a/website/public/locales/en/tasks.json +++ b/website/public/locales/en/tasks.json @@ -1,5 +1,4 @@ { - "write_initial_prompt": "Write your prompt here...", "default": { "unchanged_title": "No changes", "unchanged_message": "Are you sure you would like to continue?" @@ -12,18 +11,21 @@ "label": "Create Initial Prompts", "desc": "Write initial prompts to help Open Assistant to try replying to diverse messages.", "overview": "Create an initial message to send to the assistant", - "instruction": "Provide the initial prompts" + "instruction": "Provide the initial prompts", + "response_placeholder": "Write your prompt here..." }, "reply_as_user": { "label": "Reply as User", "desc": "Chat with Open Assistant and help improve it's responses as you interact with it.", "overview": "Given the following conversation, provide an adequate reply", - "instruction": "Provide the user's reply" + "instruction": "Provide the user's reply", + "response_placeholder": "Write your reply here..." }, "reply_as_assistant": { "label": "Reply as Assistant", "desc": "Help Open Assistant improve its responses to conversations with other users.", - "overview": "Given the following conversation, provide an adequate reply" + "overview": "Given the following conversation, provide an adequate reply", + "response_placeholder": "Write your reply here..." }, "rank_user_replies": { "label": "Rank User Replies", diff --git a/website/src/components/Tasks/CreateTask.tsx b/website/src/components/Tasks/CreateTask.tsx index 5276a183..a3c7f878 100644 --- a/website/src/components/Tasks/CreateTask.tsx +++ b/website/src/components/Tasks/CreateTask.tsx @@ -58,7 +58,11 @@ export const CreateTask = ({ text={inputText} onTextChange={textChangeHandler} thresholds={{ low: 20, medium: 40, goal: 50 }} - textareaProps={{ placeholder: t("tasks:write_initial_prompt"), isDisabled, isReadOnly: !isEditable }} + textareaProps={{ + placeholder: t(getTypeSafei18nKey(`tasks:${taskType.id}.response_placeholder`)), + isDisabled, + isReadOnly: !isEditable, + }} /> From bcecf257c745429e0ee12e8d3eb8cea4acc31442 Mon Sep 17 00:00:00 2001 From: Alan Jean Date: Sun, 29 Jan 2023 05:45:00 +0400 Subject: [PATCH 12/13] Fix two typos in the labelling translation file Fixes #981 --- website/public/locales/en/labelling.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/website/public/locales/en/labelling.json b/website/public/locales/en/labelling.json index 13582c98..e2ecc8fb 100644 --- a/website/public/locales/en/labelling.json +++ b/website/public/locales/en/labelling.json @@ -6,11 +6,11 @@ "label_message_flag_instruction": "Select any that apply to the message:", "label_message_likert_instruction": "Rate the message:", "spam.question": "Is the message spam?", - "fails_task.question": "Does the reply fail the propmpters task?", + "fails_task.question": "Does the reply fail the prompter's task?", "not_appropriate": "Not Appropriate", "pii": "Contains PII", "hate_speech": "Hate Speech", "sexual_content": "Sexual Content", "moral_judgement": "Judges Morality", - "political_content": "Politcal" + "political_content": "Political" } From 8d45a989d6e0f9f720642ba41bb70edb8ade9b4b Mon Sep 17 00:00:00 2001 From: Adrian Cowan Date: Sun, 29 Jan 2023 02:39:49 +1100 Subject: [PATCH 13/13] website: request message specific labels for on demand labelling --- website/src/components/Messages/LabelPopup.tsx | 2 +- website/src/lib/oasst_api_client.ts | 4 ++-- website/src/pages/api/valid_labels.ts | 3 ++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/website/src/components/Messages/LabelPopup.tsx b/website/src/components/Messages/LabelPopup.tsx index ac564e6d..2b6232ee 100644 --- a/website/src/components/Messages/LabelPopup.tsx +++ b/website/src/components/Messages/LabelPopup.tsx @@ -28,7 +28,7 @@ interface ValidLabelsResponse { export const LabelMessagePopup = ({ messageId, show, onClose }: LabelMessagePopupProps) => { const { t } = useTranslation(); - const { data: response } = useSWRImmutable("/api/valid_labels", get); + const { data: response } = useSWRImmutable(`/api/valid_labels?message_id=${messageId}`, get); const valid_labels = response?.valid_labels ?? []; const [values, setValues] = useState(new Array(valid_labels.length).fill(null)); diff --git a/website/src/lib/oasst_api_client.ts b/website/src/lib/oasst_api_client.ts index b9a9489e..51e691b3 100644 --- a/website/src/lib/oasst_api_client.ts +++ b/website/src/lib/oasst_api_client.ts @@ -160,8 +160,8 @@ export class OasstApiClient { /** * Returns the valid labels for messages. */ - async fetch_valid_text(): Promise { - return this.get(`/api/v1/text_labels/valid_labels`); + async fetch_valid_text(messageId?: string): Promise { + return this.get("/api/v1/text_labels/valid_labels", { message_id: messageId }); } /** diff --git a/website/src/pages/api/valid_labels.ts b/website/src/pages/api/valid_labels.ts index dca92d90..e195ef29 100644 --- a/website/src/pages/api/valid_labels.ts +++ b/website/src/pages/api/valid_labels.ts @@ -5,8 +5,9 @@ import { createApiClient } from "src/lib/oasst_client_factory"; * Returns the set of valid labels that can be applied to messages. */ const handler = withoutRole("banned", async (req, res, token) => { + const { message_id } = req.query; const client = await createApiClient(token); - const valid_labels = await client.fetch_valid_text(); + const valid_labels = await client.fetch_valid_text(message_id as string); res.status(200).json(valid_labels); });