diff --git a/backend/alembic/versions/2023_01_24_1134-8ba17b5f467a_add_message_id_to_message_reaction.py b/backend/alembic/versions/2023_01_24_1134-8ba17b5f467a_add_message_id_to_message_reaction.py new file mode 100644 index 00000000..8b877a9f --- /dev/null +++ b/backend/alembic/versions/2023_01_24_1134-8ba17b5f467a_add_message_id_to_message_reaction.py @@ -0,0 +1,34 @@ +"""add message_id to message_reaction + +Revision ID: 8ba17b5f467a +Revises: 160ac010efcc +Create Date: 2023-01-24 11:34:42.167575 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision = "8ba17b5f467a" +down_revision = "160ac010efcc" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("message_reaction", sa.Column("message_id", sqlmodel.sql.sqltypes.GUID(), nullable=True)) + op.create_index(op.f("ix_message_reaction_message_id"), "message_reaction", ["message_id"], unique=False) + op.add_column("text_labels", sa.Column("task_id", sqlmodel.sql.sqltypes.GUID(), nullable=True)) + op.create_index(op.f("ix_text_labels_task_id"), "text_labels", ["task_id"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_text_labels_task_id"), table_name="text_labels") + op.drop_column("text_labels", "task_id") + op.drop_index(op.f("ix_message_reaction_message_id"), table_name="message_reaction") + op.drop_column("message_reaction", "message_id") + # ### end Alembic commands ### diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 8ca2a413..d831eca2 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -81,6 +81,7 @@ class Settings(BaseSettings): Path(__file__).parent.parent / "test_data/realistic/realistic_seed_data.json" ) DEBUG_ALLOW_SELF_LABELING: bool = False # allow users to label their own messages + DEBUG_ALLOW_DUPLICATE_TASKS: bool = False # offer users tasks to which they already responded DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False DEBUG_SKIP_TOXICITY_CALCULATION: bool = False DEBUG_DATABASE_ECHO: bool = False diff --git a/backend/oasst_backend/models/message_reaction.py b/backend/oasst_backend/models/message_reaction.py index 4c50143e..c77c4c27 100644 --- a/backend/oasst_backend/models/message_reaction.py +++ b/backend/oasst_backend/models/message_reaction.py @@ -26,3 +26,4 @@ class MessageReaction(SQLModel, table=True): payload_type: str = Field(nullable=False, max_length=200) payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False)) api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id") + message_id: Optional[UUID] = Field(nullable=True, index=True) diff --git a/backend/oasst_backend/models/text_labels.py b/backend/oasst_backend/models/text_labels.py index 1d238ef2..aa5940b0 100644 --- a/backend/oasst_backend/models/text_labels.py +++ b/backend/oasst_backend/models/text_labels.py @@ -27,3 +27,4 @@ class TextLabels(SQLModel, table=True): sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=True) ) labels: dict[str, float] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False) + task_id: Optional[UUID] = Field(nullable=True, index=True) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index a0bc2ae7..8c746320 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -245,7 +245,7 @@ class PromptRepository: # store reaction to message reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating) - reaction = self.insert_reaction(message.id, reaction_payload) + reaction = self.insert_reaction(task_id=task.id, payload=reaction_payload, message_id=message.id) if not task.collective: task.done = True self.db.add(task) @@ -295,7 +295,7 @@ class PromptRepository: ranking_parent_id=task_payload.ranking_parent_id, message_tree_id=task_payload.message_tree_id, ) - reaction = self.insert_reaction(task.id, reaction_payload) + reaction = self.insert_reaction(task_id=task.id, payload=reaction_payload, message_id=parent_msg.id) self.journal.log_ranking(task, message_id=parent_msg.id, ranking=ranking.ranking) logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.") @@ -313,9 +313,8 @@ class PromptRepository: reaction_payload = db_payload.RankingReactionPayload( ranking=ranking.ranking, ranked_message_ids=ranked_message_ids ) - reaction = self.insert_reaction(task.id, reaction_payload) - # TODO: resolve message_id - self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking) + reaction = self.insert_reaction(task_id=task.id, payload=reaction_payload, message_id=None) + # self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking) logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.") @@ -366,7 +365,9 @@ class PromptRepository: return message_embedding @managed_tx_method(CommitMode.FLUSH) - def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction: + def insert_reaction( + self, task_id: UUID, payload: db_payload.ReactionPayload, message_id: Optional[UUID] + ) -> MessageReaction: self.ensure_user_is_enabled() container = PayloadContainer(payload=payload) @@ -376,6 +377,7 @@ class PromptRepository: payload=container, api_client_id=self.api_client.id, payload_type=type(payload).__name__, + message_id=message_id, ) self.db.add(reaction) return reaction @@ -441,6 +443,7 @@ class PromptRepository: user_id=self.user_id, text=text_labels.text, labels=text_labels.labels, + task_id=task.id if task else None, ) if message_id: diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index 2e5e3e29..d54df08b 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -682,57 +682,65 @@ class TreeManager: # calculate acceptance based on spam label return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels]) - def query_prompts_need_review(self, lang: str) -> list[Message]: - """ - Select initial prompt messages with less then required rankings in active message tree - (active == True in message_tree_state) - """ + def _query_need_review( + self, state: message_tree_state.State, required_reviews: int, root: bool, lang: str + ) -> list[Message]: - qry = ( + need_review = ( self.db.query(Message) .select_from(MessageTreeState) .join(Message, MessageTreeState.message_tree_id == Message.message_tree_id) .filter( MessageTreeState.active, - MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW, + MessageTreeState.state == state, not_(Message.review_result), not_(Message.deleted), - Message.review_count < self.cfg.num_reviews_initial_prompt, - Message.parent_id.is_(None), + Message.review_count < required_reviews, Message.lang == lang, ) ) + if root: + need_review = need_review.filter(Message.parent_id.is_(None)) + else: + need_review = need_review.filter(Message.parent_id.is_not(None)) + if not settings.DEBUG_ALLOW_SELF_LABELING: - qry = qry.filter(Message.user_id != self.pr.user_id) + need_review = need_review.filter(Message.user_id != self.pr.user_id) + + if settings.DEBUG_ALLOW_DUPLICATE_TASKS: + qry = need_review + else: + user_id = self.pr.user_id + need_review = need_review.cte(name="need_review") + qry = ( + self.db.query(Message) + .select_entity_from(need_review) + .outerjoin(TextLabels, need_review.c.id == TextLabels.message_id) + .group_by(need_review) + .having( + func.count(TextLabels.id).filter(TextLabels.task_id.is_not(None), TextLabels.user_id == user_id) + == 0 + ) + ) return qry.all() + def query_prompts_need_review(self, lang: str) -> list[Message]: + """ + Select initial prompt messages with less then required rankings in active message tree + (active == True in message_tree_state) + """ + return self._query_need_review( + message_tree_state.State.INITIAL_PROMPT_REVIEW, self.cfg.num_reviews_initial_prompt, True, lang + ) + def query_replies_need_review(self, lang: str) -> list[Message]: """ Select child messages (parent_id IS NOT NULL) with less then required rankings in active message tree (active == True in message_tree_state) """ - - qry = ( - self.db.query(Message) - .select_from(MessageTreeState) - .join(Message, MessageTreeState.message_tree_id == Message.message_tree_id) - .filter( - MessageTreeState.active, - MessageTreeState.state == message_tree_state.State.GROWING, - not_(Message.review_result), - not_(Message.deleted), - Message.review_count < self.cfg.num_reviews_reply, - Message.parent_id.is_not(None), - Message.lang == lang, - ) - ) - - if not settings.DEBUG_ALLOW_SELF_LABELING: - qry = qry.filter(Message.user_id != self.pr.user_id) - - return qry.all() + return self._query_need_review(message_tree_state.State.GROWING, self.cfg.num_reviews_reply, False, lang) _sql_find_incomplete_rankings = """ -- find incomplete rankings @@ -748,17 +756,28 @@ WHERE mts.active -- only consider active trees AND m.parent_id IS NOT NULL -- ignore initial prompts GROUP BY m.parent_id, m.role HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings +""" + + _sql_find_incomplete_rankings_ex = f""" +-- incomplete rankings but exclude of current user +WITH incomplete_rankings AS ({_sql_find_incomplete_rankings}) +SELECT ir.* FROM incomplete_rankings ir + LEFT JOIN message_reaction mr ON ir.parent_id = mr.message_id AND mr.payload_type = 'RankingReactionPayload' +GROUP BY ir.parent_id, ir.role, ir.children_count, ir.child_min_ranking_count, ir.completed_rankings +HAVING(COUNT(mr.message_id) FILTER (WHERE mr.user_id = :user_id) = 0) """ def query_incomplete_rankings(self, lang: str) -> list[IncompleteRankingsRow]: """Query parents which have childern that need further rankings""" + user_id = self.pr.user_id if not settings.DEBUG_ALLOW_DUPLICATE_TASKS else None r = self.db.execute( - text(self._sql_find_incomplete_rankings), + text(self._sql_find_incomplete_rankings_ex), { "num_required_rankings": self.cfg.num_required_rankings, "ranking_state": message_tree_state.State.RANKING, "lang": lang, + "user_id": user_id, }, ) return [IncompleteRankingsRow.from_orm(x) for x in r.all()] @@ -779,17 +798,20 @@ WHERE mts.active -- only consider active trees AND (c.review_result OR coalesce(c.review_count, 0) < :num_reviews_reply) -- don't count children with negative review but count elements under review GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children + AND COUNT(c.id) FILTER (WHERE c.user_id = :user_id) = 0 -- without reply by user """ def query_extendible_parents(self, lang: str) -> list[ExtendibleParentRow]: """Query parent messages that have not reached the maximum number of replies.""" + user_id = self.pr.user_id if not settings.DEBUG_ALLOW_DUPLICATE_TASKS else None r = self.db.execute( text(self._sql_find_extendible_parents), { "growing_state": message_tree_state.State.GROWING, "num_reviews_reply": self.cfg.num_reviews_reply, "lang": lang, + "user_id": user_id, }, ) return [ExtendibleParentRow.from_orm(x) for x in r.all()] @@ -813,12 +835,14 @@ HAVING COUNT(m.id) < mts.goal_tree_size def query_extendible_trees(self, lang: str) -> list[ActiveTreeSizeRow]: """Query size of active message trees in growing state.""" + user_id = self.pr.user_id if not settings.DEBUG_ALLOW_DUPLICATE_TASKS else None r = self.db.execute( text(self._sql_find_extendible_trees), { "growing_state": message_tree_state.State.GROWING, "num_reviews_reply": self.cfg.num_reviews_reply, "lang": lang, + "user_id": user_id, }, ) return [ActiveTreeSizeRow.from_orm(x) for x in r.all()] diff --git a/backend/oasst_backend/utils/database_utils.py b/backend/oasst_backend/utils/database_utils.py index 803b81e8..34113cfc 100644 --- a/backend/oasst_backend/utils/database_utils.py +++ b/backend/oasst_backend/utils/database_utils.py @@ -26,7 +26,6 @@ class CommitMode(IntEnum): * managed_tx_method and async_managed_tx_method methods are decorators functions * to be used on class functions. It expects the Class to have a 'db' Session object * initialised -* TODO: tx method decorator for non class methods """ diff --git a/scripts/backend-development/run-local.sh b/scripts/backend-development/run-local.sh index 7366cde6..064fd843 100755 --- a/scripts/backend-development/run-local.sh +++ b/scripts/backend-development/run-local.sh @@ -7,6 +7,7 @@ pushd "$parent_path/../../backend" export DEBUG_USE_SEED_DATA=True export DEBUG_SKIP_TOXICITY_CALCULATION=True export DEBUG_ALLOW_SELF_LABELING=True +export DEBUG_ALLOW_DUPLICATE_TASKS=True export DEBUG_SKIP_EMBEDDING_COMPUTATION=True uvicorn main:app --reload --port 8080 --host 0.0.0.0