mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
907 avoid duplicate labeling & ranking tasks (#923)
* store message_id in message_reaction and task_id in text_labels * exclude tasks to which users have already responded * remove test code * fix join in find_incomplete_rankings_ex
This commit is contained in:
+34
@@ -0,0 +1,34 @@
|
||||
"""add message_id to message_reaction
|
||||
|
||||
Revision ID: 8ba17b5f467a
|
||||
Revises: 160ac010efcc
|
||||
Create Date: 2023-01-24 11:34:42.167575
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8ba17b5f467a"
|
||||
down_revision = "160ac010efcc"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the new reference columns and their indexes.

    Adds a nullable GUID column ``message_id`` to ``message_reaction`` and a
    nullable GUID column ``task_id`` to ``text_labels``, each followed by a
    non-unique index on the new column.
    """
    # (table, new column, index name) -- processed in the listed order.
    additions = (
        ("message_reaction", "message_id", "ix_message_reaction_message_id"),
        ("text_labels", "task_id", "ix_text_labels_task_id"),
    )
    for table_name, column_name, index_name in additions:
        new_column = sa.Column(column_name, sqlmodel.sql.sqltypes.GUID(), nullable=True)
        op.add_column(table_name, new_column)
        op.create_index(op.f(index_name), table_name, [column_name], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Remove the columns and indexes created by this revision.

    Drops the index and then the column for ``text_labels.task_id`` first,
    followed by ``message_reaction.message_id``.
    """
    # (table, column, index name) -- processed in the listed order; each index
    # is dropped before the column it covers.
    removals = (
        ("text_labels", "task_id", "ix_text_labels_task_id"),
        ("message_reaction", "message_id", "ix_message_reaction_message_id"),
    )
    for table_name, column_name, index_name in removals:
        op.drop_index(op.f(index_name), table_name=table_name)
        op.drop_column(table_name, column_name)
|
||||
@@ -81,6 +81,7 @@ class Settings(BaseSettings):
|
||||
Path(__file__).parent.parent / "test_data/realistic/realistic_seed_data.json"
|
||||
)
|
||||
DEBUG_ALLOW_SELF_LABELING: bool = False # allow users to label their own messages
|
||||
DEBUG_ALLOW_DUPLICATE_TASKS: bool = False # offer users tasks to which they already responded
|
||||
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
|
||||
DEBUG_SKIP_TOXICITY_CALCULATION: bool = False
|
||||
DEBUG_DATABASE_ECHO: bool = False
|
||||
|
||||
@@ -26,3 +26,4 @@ class MessageReaction(SQLModel, table=True):
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
message_id: Optional[UUID] = Field(nullable=True, index=True)
|
||||
|
||||
@@ -27,3 +27,4 @@ class TextLabels(SQLModel, table=True):
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=True)
|
||||
)
|
||||
labels: dict[str, float] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False)
|
||||
task_id: Optional[UUID] = Field(nullable=True, index=True)
|
||||
|
||||
@@ -245,7 +245,7 @@ class PromptRepository:
|
||||
|
||||
# store reaction to message
|
||||
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
|
||||
reaction = self.insert_reaction(message.id, reaction_payload)
|
||||
reaction = self.insert_reaction(task_id=task.id, payload=reaction_payload, message_id=message.id)
|
||||
if not task.collective:
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
@@ -295,7 +295,7 @@ class PromptRepository:
|
||||
ranking_parent_id=task_payload.ranking_parent_id,
|
||||
message_tree_id=task_payload.message_tree_id,
|
||||
)
|
||||
reaction = self.insert_reaction(task.id, reaction_payload)
|
||||
reaction = self.insert_reaction(task_id=task.id, payload=reaction_payload, message_id=parent_msg.id)
|
||||
self.journal.log_ranking(task, message_id=parent_msg.id, ranking=ranking.ranking)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
|
||||
@@ -313,9 +313,8 @@ class PromptRepository:
|
||||
reaction_payload = db_payload.RankingReactionPayload(
|
||||
ranking=ranking.ranking, ranked_message_ids=ranked_message_ids
|
||||
)
|
||||
reaction = self.insert_reaction(task.id, reaction_payload)
|
||||
# TODO: resolve message_id
|
||||
self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(task_id=task.id, payload=reaction_payload, message_id=None)
|
||||
# self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
|
||||
|
||||
@@ -366,7 +365,9 @@ class PromptRepository:
|
||||
return message_embedding
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
|
||||
def insert_reaction(
|
||||
self, task_id: UUID, payload: db_payload.ReactionPayload, message_id: Optional[UUID]
|
||||
) -> MessageReaction:
|
||||
self.ensure_user_is_enabled()
|
||||
|
||||
container = PayloadContainer(payload=payload)
|
||||
@@ -376,6 +377,7 @@ class PromptRepository:
|
||||
payload=container,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=type(payload).__name__,
|
||||
message_id=message_id,
|
||||
)
|
||||
self.db.add(reaction)
|
||||
return reaction
|
||||
@@ -441,6 +443,7 @@ class PromptRepository:
|
||||
user_id=self.user_id,
|
||||
text=text_labels.text,
|
||||
labels=text_labels.labels,
|
||||
task_id=task.id if task else None,
|
||||
)
|
||||
|
||||
if message_id:
|
||||
|
||||
@@ -682,57 +682,65 @@ class TreeManager:
|
||||
# calculate acceptance based on spam label
|
||||
return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels])
|
||||
|
||||
def query_prompts_need_review(self, lang: str) -> list[Message]:
|
||||
"""
|
||||
Select initial prompt messages with less then required rankings in active message tree
|
||||
(active == True in message_tree_state)
|
||||
"""
|
||||
def _query_need_review(
|
||||
self, state: message_tree_state.State, required_reviews: int, root: bool, lang: str
|
||||
) -> list[Message]:
|
||||
|
||||
qry = (
|
||||
need_review = (
|
||||
self.db.query(Message)
|
||||
.select_from(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
|
||||
.filter(
|
||||
MessageTreeState.active,
|
||||
MessageTreeState.state == message_tree_state.State.INITIAL_PROMPT_REVIEW,
|
||||
MessageTreeState.state == state,
|
||||
not_(Message.review_result),
|
||||
not_(Message.deleted),
|
||||
Message.review_count < self.cfg.num_reviews_initial_prompt,
|
||||
Message.parent_id.is_(None),
|
||||
Message.review_count < required_reviews,
|
||||
Message.lang == lang,
|
||||
)
|
||||
)
|
||||
|
||||
if root:
|
||||
need_review = need_review.filter(Message.parent_id.is_(None))
|
||||
else:
|
||||
need_review = need_review.filter(Message.parent_id.is_not(None))
|
||||
|
||||
if not settings.DEBUG_ALLOW_SELF_LABELING:
|
||||
qry = qry.filter(Message.user_id != self.pr.user_id)
|
||||
need_review = need_review.filter(Message.user_id != self.pr.user_id)
|
||||
|
||||
if settings.DEBUG_ALLOW_DUPLICATE_TASKS:
|
||||
qry = need_review
|
||||
else:
|
||||
user_id = self.pr.user_id
|
||||
need_review = need_review.cte(name="need_review")
|
||||
qry = (
|
||||
self.db.query(Message)
|
||||
.select_entity_from(need_review)
|
||||
.outerjoin(TextLabels, need_review.c.id == TextLabels.message_id)
|
||||
.group_by(need_review)
|
||||
.having(
|
||||
func.count(TextLabels.id).filter(TextLabels.task_id.is_not(None), TextLabels.user_id == user_id)
|
||||
== 0
|
||||
)
|
||||
)
|
||||
|
||||
return qry.all()
|
||||
|
||||
def query_prompts_need_review(self, lang: str) -> list[Message]:
    """
    Select initial prompt messages that still need reviews.

    Delegates to ``_query_need_review`` for trees in the INITIAL_PROMPT_REVIEW
    state, using ``cfg.num_reviews_initial_prompt`` as the required review
    count and restricting the result to root messages in language ``lang``.
    """
    return self._query_need_review(
        state=message_tree_state.State.INITIAL_PROMPT_REVIEW,
        required_reviews=self.cfg.num_reviews_initial_prompt,
        root=True,
        lang=lang,
    )
|
||||
|
||||
def query_replies_need_review(self, lang: str) -> list[Message]:
|
||||
"""
|
||||
Select child messages (parent_id IS NOT NULL) with fewer than the required number of reviews
|
||||
in active message tree (active == True in message_tree_state)
|
||||
"""
|
||||
|
||||
qry = (
|
||||
self.db.query(Message)
|
||||
.select_from(MessageTreeState)
|
||||
.join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
|
||||
.filter(
|
||||
MessageTreeState.active,
|
||||
MessageTreeState.state == message_tree_state.State.GROWING,
|
||||
not_(Message.review_result),
|
||||
not_(Message.deleted),
|
||||
Message.review_count < self.cfg.num_reviews_reply,
|
||||
Message.parent_id.is_not(None),
|
||||
Message.lang == lang,
|
||||
)
|
||||
)
|
||||
|
||||
if not settings.DEBUG_ALLOW_SELF_LABELING:
|
||||
qry = qry.filter(Message.user_id != self.pr.user_id)
|
||||
|
||||
return qry.all()
|
||||
return self._query_need_review(message_tree_state.State.GROWING, self.cfg.num_reviews_reply, False, lang)
|
||||
|
||||
_sql_find_incomplete_rankings = """
|
||||
-- find incomplete rankings
|
||||
@@ -748,17 +756,28 @@ WHERE mts.active -- only consider active trees
|
||||
AND m.parent_id IS NOT NULL -- ignore initial prompts
|
||||
GROUP BY m.parent_id, m.role
|
||||
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
|
||||
"""
|
||||
|
||||
_sql_find_incomplete_rankings_ex = f"""
|
||||
-- incomplete rankings but exclude of current user
|
||||
WITH incomplete_rankings AS ({_sql_find_incomplete_rankings})
|
||||
SELECT ir.* FROM incomplete_rankings ir
|
||||
LEFT JOIN message_reaction mr ON ir.parent_id = mr.message_id AND mr.payload_type = 'RankingReactionPayload'
|
||||
GROUP BY ir.parent_id, ir.role, ir.children_count, ir.child_min_ranking_count, ir.completed_rankings
|
||||
HAVING(COUNT(mr.message_id) FILTER (WHERE mr.user_id = :user_id) = 0)
|
||||
"""
|
||||
|
||||
def query_incomplete_rankings(self, lang: str) -> list[IncompleteRankingsRow]:
|
||||
"""Query parents which have childern that need further rankings"""
|
||||
|
||||
user_id = self.pr.user_id if not settings.DEBUG_ALLOW_DUPLICATE_TASKS else None
|
||||
r = self.db.execute(
|
||||
text(self._sql_find_incomplete_rankings),
|
||||
text(self._sql_find_incomplete_rankings_ex),
|
||||
{
|
||||
"num_required_rankings": self.cfg.num_required_rankings,
|
||||
"ranking_state": message_tree_state.State.RANKING,
|
||||
"lang": lang,
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
return [IncompleteRankingsRow.from_orm(x) for x in r.all()]
|
||||
@@ -779,17 +798,20 @@ WHERE mts.active -- only consider active trees
|
||||
AND (c.review_result OR coalesce(c.review_count, 0) < :num_reviews_reply) -- don't count children with negative review but count elements under review
|
||||
GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count
|
||||
HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
|
||||
AND COUNT(c.id) FILTER (WHERE c.user_id = :user_id) = 0 -- without reply by user
|
||||
"""
|
||||
|
||||
def query_extendible_parents(self, lang: str) -> list[ExtendibleParentRow]:
    """Query parent messages that have not reached the maximum number of replies.

    Executes ``_sql_find_extendible_parents`` for language ``lang`` and
    converts every result row into an ``ExtendibleParentRow``.
    """
    if settings.DEBUG_ALLOW_DUPLICATE_TASKS:
        # Bind NULL as :user_id when duplicate tasks are allowed
        # (presumably disables the per-user exclusion in the SQL -- confirm).
        current_user_id = None
    else:
        current_user_id = self.pr.user_id

    bind_params = {
        "growing_state": message_tree_state.State.GROWING,
        "num_reviews_reply": self.cfg.num_reviews_reply,
        "lang": lang,
        "user_id": current_user_id,
    }
    rows = self.db.execute(text(self._sql_find_extendible_parents), bind_params).all()
    return [ExtendibleParentRow.from_orm(row) for row in rows]
|
||||
@@ -813,12 +835,14 @@ HAVING COUNT(m.id) < mts.goal_tree_size
|
||||
def query_extendible_trees(self, lang: str) -> list[ActiveTreeSizeRow]:
    """Query size of active message trees in growing state.

    Executes ``_sql_find_extendible_trees`` for language ``lang`` and converts
    every result row into an ``ActiveTreeSizeRow``.
    """
    if settings.DEBUG_ALLOW_DUPLICATE_TASKS:
        # Bind NULL as :user_id when duplicate tasks are allowed.
        current_user_id = None
    else:
        current_user_id = self.pr.user_id

    bind_params = {
        "growing_state": message_tree_state.State.GROWING,
        "num_reviews_reply": self.cfg.num_reviews_reply,
        "lang": lang,
        "user_id": current_user_id,
    }
    rows = self.db.execute(text(self._sql_find_extendible_trees), bind_params).all()
    return [ActiveTreeSizeRow.from_orm(row) for row in rows]
|
||||
|
||||
@@ -26,7 +26,6 @@ class CommitMode(IntEnum):
|
||||
* managed_tx_method and async_managed_tx_method methods are decorators functions
|
||||
* to be used on class functions. It expects the Class to have a 'db' Session object
|
||||
* initialised
|
||||
* TODO: tx method decorator for non class methods
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ pushd "$parent_path/../../backend"
|
||||
export DEBUG_USE_SEED_DATA=True
|
||||
export DEBUG_SKIP_TOXICITY_CALCULATION=True
|
||||
export DEBUG_ALLOW_SELF_LABELING=True
|
||||
export DEBUG_ALLOW_DUPLICATE_TASKS=True
|
||||
export DEBUG_SKIP_EMBEDDING_COMPUTATION=True
|
||||
|
||||
uvicorn main:app --reload --port 8080 --host 0.0.0.0
|
||||
|
||||
Reference in New Issue
Block a user