907 avoid duplicate labeling & ranking tasks (#923)

* store message_id in message_reactions and task_id in text_labels

* exclude tasks to which users have already responded

* remove test code

* fix join in find_incomplete_rankings_ex
This commit is contained in:
Andreas Köpf
2023-01-24 17:33:15 +01:00
committed by GitHub
parent d72f7771ca
commit ffaf5c48d1
8 changed files with 102 additions and 38 deletions
@@ -0,0 +1,34 @@
"""add message_id to message_reaction
Revision ID: 8ba17b5f467a
Revises: 160ac010efcc
Create Date: 2023-01-24 11:34:42.167575
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
# revision identifiers, used by Alembic.
revision = "8ba17b5f467a"
down_revision = "160ac010efcc"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add ``message_reaction.message_id`` and ``text_labels.task_id``.

    Both columns are nullable GUIDs and each one gets a non-unique index.
    The operations were originally produced by Alembic autogenerate.
    """
    # (table, new column, index name) -- processed in this order
    additions = (
        ("message_reaction", "message_id", "ix_message_reaction_message_id"),
        ("text_labels", "task_id", "ix_text_labels_task_id"),
    )
    for table_name, column_name, index_name in additions:
        guid_column = sa.Column(column_name, sqlmodel.sql.sqltypes.GUID(), nullable=True)
        op.add_column(table_name, guid_column)
        op.create_index(op.f(index_name), table_name, [column_name], unique=False)
def downgrade() -> None:
    """Drop the indexes and columns created by ``upgrade``.

    ``text_labels.task_id`` is removed first, then ``message_reaction.message_id``;
    for each table the index is dropped before its column.
    """
    # (table, column, index name) -- reverse of the order used in upgrade()
    removals = (
        ("text_labels", "task_id", "ix_text_labels_task_id"),
        ("message_reaction", "message_id", "ix_message_reaction_message_id"),
    )
    for table_name, column_name, index_name in removals:
        op.drop_index(op.f(index_name), table_name=table_name)
        op.drop_column(table_name, column_name)
+1
View File
@@ -81,6 +81,7 @@ class Settings(BaseSettings):
Path(__file__).parent.parent / "test_data/realistic/realistic_seed_data.json"
)
DEBUG_ALLOW_SELF_LABELING: bool = False # allow users to label their own messages
DEBUG_ALLOW_DUPLICATE_TASKS: bool = False # offer users tasks to which they already responded
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
DEBUG_SKIP_TOXICITY_CALCULATION: bool = False
DEBUG_DATABASE_ECHO: bool = False
@@ -26,3 +26,4 @@ class MessageReaction(SQLModel, table=True):
payload_type: str = Field(nullable=False, max_length=200)
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
message_id: Optional[UUID] = Field(nullable=True, index=True)
@@ -27,3 +27,4 @@ class TextLabels(SQLModel, table=True):
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=True)
)
labels: dict[str, float] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False)
task_id: Optional[UUID] = Field(nullable=True, index=True)
+9 -6
View File
@@ -245,7 +245,7 @@ class PromptRepository:
# store reaction to message
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
reaction = self.insert_reaction(message.id, reaction_payload)
reaction = self.insert_reaction(task_id=task.id, payload=reaction_payload, message_id=message.id)
if not task.collective:
task.done = True
self.db.add(task)
@@ -295,7 +295,7 @@ class PromptRepository:
ranking_parent_id=task_payload.ranking_parent_id,
message_tree_id=task_payload.message_tree_id,
)
reaction = self.insert_reaction(task.id, reaction_payload)
reaction = self.insert_reaction(task_id=task.id, payload=reaction_payload, message_id=parent_msg.id)
self.journal.log_ranking(task, message_id=parent_msg.id, ranking=ranking.ranking)
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
@@ -313,9 +313,8 @@ class PromptRepository:
reaction_payload = db_payload.RankingReactionPayload(
ranking=ranking.ranking, ranked_message_ids=ranked_message_ids
)
reaction = self.insert_reaction(task.id, reaction_payload)
# TODO: resolve message_id
self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
reaction = self.insert_reaction(task_id=task.id, payload=reaction_payload, message_id=None)
# self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
@@ -366,7 +365,9 @@ class PromptRepository:
return message_embedding
@managed_tx_method(CommitMode.FLUSH)
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
def insert_reaction(
self, task_id: UUID, payload: db_payload.ReactionPayload, message_id: Optional[UUID]
) -> MessageReaction:
self.ensure_user_is_enabled()
container = PayloadContainer(payload=payload)
@@ -376,6 +377,7 @@ class PromptRepository:
payload=container,
api_client_id=self.api_client.id,
payload_type=type(payload).__name__,
message_id=message_id,
)
self.db.add(reaction)
return reaction
@@ -441,6 +443,7 @@ class PromptRepository:
user_id=self.user_id,
text=text_labels.text,
labels=text_labels.labels,
task_id=task.id if task else None,
)
if message_id:
+55 -31
View File
@@ -682,57 +682,65 @@ class TreeManager:
# calculate acceptance based on spam label
return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels])
def query_prompts_need_review(self, lang: str) -> list[Message]:
"""
Select initial prompt messages with less then required rankings in active message tree
(active == True in message_tree_state)
"""
def _query_need_review(
    self, state: message_tree_state.State, required_reviews: int, root: bool, lang: str
) -> list[Message]:
    """Select messages of active message trees that still need reviews.

    Args:
        state: State the message's tree must be in.
        required_reviews: Only messages whose ``review_count`` is below this value are selected.
        root: ``True`` selects messages without a parent, ``False`` selects messages with a parent.
        lang: Language the message must have.

    Returns:
        Matching messages that are not deleted and whose ``review_result`` is not true.
        Messages authored by ``self.pr.user_id`` are left out unless
        ``settings.DEBUG_ALLOW_SELF_LABELING`` is set. Messages for which that user already
        has a ``TextLabels`` row with a ``task_id`` are left out unless
        ``settings.DEBUG_ALLOW_DUPLICATE_TASKS`` is set.
    """
    need_review = (
        self.db.query(Message)
        .select_from(MessageTreeState)
        .join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
        .filter(
            MessageTreeState.active,
            MessageTreeState.state == state,
            not_(Message.review_result),
            not_(Message.deleted),
            Message.review_count < required_reviews,
            Message.lang == lang,
        )
    )

    # root prompts have no parent; replies always have one
    if root:
        need_review = need_review.filter(Message.parent_id.is_(None))
    else:
        need_review = need_review.filter(Message.parent_id.is_not(None))

    if not settings.DEBUG_ALLOW_SELF_LABELING:
        need_review = need_review.filter(Message.user_id != self.pr.user_id)

    if settings.DEBUG_ALLOW_DUPLICATE_TASKS:
        qry = need_review
    else:
        user_id = self.pr.user_id
        # Wrap the candidates in a CTE and keep only those for which the user has
        # no task-bound label yet (outer join + filtered count == 0).
        need_review = need_review.cte(name="need_review")
        qry = (
            self.db.query(Message)
            .select_entity_from(need_review)
            .outerjoin(TextLabels, need_review.c.id == TextLabels.message_id)
            .group_by(need_review)
            .having(
                func.count(TextLabels.id).filter(TextLabels.task_id.is_not(None), TextLabels.user_id == user_id)
                == 0
            )
        )

    return qry.all()
def query_prompts_need_review(self, lang: str) -> list[Message]:
    """
    Select initial prompts (messages without a parent) that have fewer reviews than
    ``self.cfg.num_reviews_initial_prompt`` and belong to an active message tree
    in the INITIAL_PROMPT_REVIEW state.
    """
    return self._query_need_review(
        state=message_tree_state.State.INITIAL_PROMPT_REVIEW,
        required_reviews=self.cfg.num_reviews_initial_prompt,
        root=True,
        lang=lang,
    )
def query_replies_need_review(self, lang: str) -> list[Message]:
    """
    Select child messages (parent_id IS NOT NULL) that have fewer reviews than
    ``self.cfg.num_reviews_reply`` and belong to an active message tree in the
    GROWING state.

    The filtering itself (including the self-labeling and duplicate-task exclusions)
    is done by ``_query_need_review``.
    """
    return self._query_need_review(message_tree_state.State.GROWING, self.cfg.num_reviews_reply, False, lang)
_sql_find_incomplete_rankings = """
-- find incomplete rankings
@@ -748,17 +756,28 @@ WHERE mts.active -- only consider active trees
AND m.parent_id IS NOT NULL -- ignore initial prompts
GROUP BY m.parent_id, m.role
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
"""
# Statement built on top of `_sql_find_incomplete_rankings`: the base query is embedded
# as the CTE `incomplete_rankings`, left-joined to `message_reaction` rows of payload type
# 'RankingReactionPayload' via `ir.parent_id = mr.message_id`, and only those parents are
# kept for which no such reaction by `:user_id` exists. Requires the `:user_id` bind
# parameter in addition to the bind parameters of the embedded statement.
_sql_find_incomplete_rankings_ex = f"""
-- incomplete rankings but exclude of current user
WITH incomplete_rankings AS ({_sql_find_incomplete_rankings})
SELECT ir.* FROM incomplete_rankings ir
LEFT JOIN message_reaction mr ON ir.parent_id = mr.message_id AND mr.payload_type = 'RankingReactionPayload'
GROUP BY ir.parent_id, ir.role, ir.children_count, ir.child_min_ranking_count, ir.completed_rankings
HAVING(COUNT(mr.message_id) FILTER (WHERE mr.user_id = :user_id) = 0)
"""
def query_incomplete_rankings(self, lang: str) -> list[IncompleteRankingsRow]:
    """Query parents which have children that need further rankings.

    Parents for which ``self.pr.user_id`` already stored a ranking reaction are left out
    unless ``settings.DEBUG_ALLOW_DUPLICATE_TASKS`` is set.
    """
    # With user_id = None the `mr.user_id = :user_id` filter of the statement never
    # matches, so no parent is excluded.
    user_id = self.pr.user_id if not settings.DEBUG_ALLOW_DUPLICATE_TASKS else None
    r = self.db.execute(
        text(self._sql_find_incomplete_rankings_ex),
        {
            "num_required_rankings": self.cfg.num_required_rankings,
            "ranking_state": message_tree_state.State.RANKING,
            "lang": lang,
            "user_id": user_id,
        },
    )
    return [IncompleteRankingsRow.from_orm(x) for x in r.all()]
@@ -779,17 +798,20 @@ WHERE mts.active -- only consider active trees
AND (c.review_result OR coalesce(c.review_count, 0) < :num_reviews_reply) -- don't count children with negative review but count elements under review
GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count
HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
AND COUNT(c.id) FILTER (WHERE c.user_id = :user_id) = 0 -- without reply by user
"""
def query_extendible_parents(self, lang: str) -> list[ExtendibleParentRow]:
    """Return parent messages that can still receive replies.

    Runs ``_sql_find_extendible_parents`` for the given language. The ``user_id`` bind
    parameter is ``self.pr.user_id``, or ``None`` when
    ``settings.DEBUG_ALLOW_DUPLICATE_TASKS`` is set.
    """
    if settings.DEBUG_ALLOW_DUPLICATE_TASKS:
        excluded_user_id = None
    else:
        excluded_user_id = self.pr.user_id

    statement = text(self._sql_find_extendible_parents)
    bind_params = {
        "growing_state": message_tree_state.State.GROWING,
        "num_reviews_reply": self.cfg.num_reviews_reply,
        "lang": lang,
        "user_id": excluded_user_id,
    }
    rows = self.db.execute(statement, bind_params).all()
    return [ExtendibleParentRow.from_orm(row) for row in rows]
@@ -813,12 +835,14 @@ HAVING COUNT(m.id) < mts.goal_tree_size
def query_extendible_trees(self, lang: str) -> list[ActiveTreeSizeRow]:
    """Return size information for active message trees in the growing state.

    Runs ``_sql_find_extendible_trees`` for the given language. The ``user_id`` bind
    parameter is ``self.pr.user_id``, or ``None`` when
    ``settings.DEBUG_ALLOW_DUPLICATE_TASKS`` is set.
    """
    # NOTE(review): how the statement uses :user_id is not visible here -- presumably
    # it excludes trees/parents the user already replied to; confirm against the SQL.
    if settings.DEBUG_ALLOW_DUPLICATE_TASKS:
        excluded_user_id = None
    else:
        excluded_user_id = self.pr.user_id

    statement = text(self._sql_find_extendible_trees)
    bind_params = {
        "growing_state": message_tree_state.State.GROWING,
        "num_reviews_reply": self.cfg.num_reviews_reply,
        "lang": lang,
        "user_id": excluded_user_id,
    }
    rows = self.db.execute(statement, bind_params).all()
    return [ActiveTreeSizeRow.from_orm(row) for row in rows]
@@ -26,7 +26,6 @@ class CommitMode(IntEnum):
* managed_tx_method and async_managed_tx_method methods are decorators functions
* to be used on class functions. It expects the Class to have a 'db' Session object
* initialised
* TODO: tx method decorator for non class methods
"""
+1
View File
@@ -7,6 +7,7 @@ pushd "$parent_path/../../backend"
export DEBUG_USE_SEED_DATA=True
export DEBUG_SKIP_TOXICITY_CALCULATION=True
export DEBUG_ALLOW_SELF_LABELING=True
export DEBUG_ALLOW_DUPLICATE_TASKS=True
export DEBUG_SKIP_EMBEDDING_COMPUTATION=True
uvicorn main:app --reload --port 8080 --host 0.0.0.0