Add lang-tag based task selection (lang-separation) (#863)

* lang based task selection

* use BCP 47 instead of ISO 639-1

* add Field(None, nullable=True)

* update migration script down_revision
This commit is contained in:
Andreas Köpf
2023-01-20 19:58:33 +01:00
committed by GitHub
parent 70fc80aa08
commit 2d21b65ed0
9 changed files with 90 additions and 26 deletions
@@ -0,0 +1,29 @@
"""use 'en' instead 'en-US' as default lang
Revision ID: 160ac010efcc
Revises: 4f26fec4d204
Create Date: 2023-01-20 16:50:00
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
# `revision` is this migration's own ID and `down_revision` is the migration it
# revises; both match the "Revision ID" / "Revises" lines in the module docstring.
revision = "160ac010efcc"
down_revision = "4f26fec4d204"
# No branch labels and no cross-revision dependencies are declared.
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Recreate ``message.lang`` as a NOT NULL ``String(32)`` with server default ``"en"``.

    The column is dropped and re-added instead of altered in place, so values
    stored before this migration are not carried over.
    """
    table_name = "message"
    lang_column = sa.Column("lang", sa.String(length=32), server_default="en", nullable=False)

    # Remove the old column first, then add the replacement definition.
    op.drop_column(table_name, "lang")
    op.add_column(table_name, lang_column)
def downgrade() -> None:
    """Restore ``message.lang`` to its previous ``VARCHAR(200) NOT NULL`` definition.

    The column is dropped and re-added, so values written after the upgrade are
    not carried over. Adding a NOT NULL column without a default is rejected by
    the database when the table already contains rows, so the column is first
    created with a temporary server default of ``"en-US"`` (the default named
    in this migration's title as the one being replaced) to backfill existing
    rows. The server default is then removed again, because the previous schema
    did not declare one.
    """
    op.drop_column("message", "lang")
    op.add_column(
        "message",
        sa.Column(
            "lang",
            sa.VARCHAR(length=200),
            server_default="en-US",  # temporary: lets existing rows satisfy NOT NULL
            autoincrement=False,
            nullable=False,
        ),
    )
    # Passing server_default=None tells Alembic to drop the column default,
    # returning the column to its pre-upgrade definition.
    op.alter_column(
        "message",
        "lang",
        existing_type=sa.VARCHAR(length=200),
        existing_nullable=False,
        server_default=None,
    )
+2
View File
@@ -128,6 +128,7 @@ if settings.DEBUG_USE_SEED_DATA:
user_message_id: str
parent_message_id: Optional[str]
text: str
lang: Optional[str]
role: str
tree_state: Optional[message_tree_state.State]
@@ -184,6 +185,7 @@ if settings.DEBUG_USE_SEED_DATA:
tr.bind_frontend_message_id(task.id, msg.task_message_id)
message = pr.store_text_reply(
msg.text,
msg.lang,
msg.task_message_id,
msg.user_message_id,
review_count=5,
+3 -2
View File
@@ -39,7 +39,7 @@ def request_task(
pr.ensure_user_is_enabled()
tm = TreeManager(db, pr)
task, message_tree_id, parent_message_id = tm.next_task(request.type)
task, message_tree_id, parent_message_id = tm.next_task(desired_task_type=request.type, lang=request.lang)
pr.task_repository.store_task(task, message_tree_id, parent_message_id, request.collective)
except OasstError:
@@ -54,6 +54,7 @@ def request_task(
def tasks_availability(
*,
user: Optional[protocol_schema.User] = None,
lang: Optional[str] = "en",
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
):
@@ -62,7 +63,7 @@ def tasks_availability(
try:
pr = PromptRepository(db, api_client, client_user=user)
tm = TreeManager(db, pr)
return tm.determine_task_availability()
return tm.determine_task_availability(lang)
except OasstError:
raise
+4 -2
View File
@@ -10,6 +10,7 @@ def prepare_message(m: Message) -> protocol.Message:
frontend_message_id=m.frontend_message_id,
parent_id=m.parent_id,
text=m.text,
lang=m.lang,
is_assistant=(m.role == "assistant"),
created_date=m.created_date,
)
@@ -22,10 +23,11 @@ def prepare_message_list(messages: list[Message]) -> list[protocol.Message]:
def prepare_conversation_message_list(messages: list[Message]) -> list[protocol.ConversationMessage]:
    """Convert database messages into protocol conversation messages.

    Each input message yields one ``protocol.ConversationMessage`` carrying its
    id, frontend message id, text and language tag. ``is_assistant`` is true
    when the message role is ``"assistant"``. Input order is preserved.

    Every keyword argument is passed exactly once; repeating ``text`` or
    ``is_assistant`` in the call would be a syntax error.
    """
    return [
        protocol.ConversationMessage(
            id=message.id,
            frontend_message_id=message.frontend_message_id,
            text=message.text,
            lang=message.lang,
            is_assistant=(message.role == "assistant"),
        )
        for message in messages
    ]
+1 -1
View File
@@ -38,7 +38,7 @@ class Message(SQLModel, table=True):
payload: Optional[PayloadContainer] = Field(
sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True)
)
lang: str = Field(nullable=False, max_length=200, default="en-US")
lang: str = Field(sa_column=sa.Column(sa.String(32), server_default="en", nullable=False))
depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
@@ -85,6 +85,7 @@ class PromptRepository:
task_id: UUID,
role: str,
payload: db_payload.MessagePayload,
lang: str,
payload_type: str = None,
depth: int = 0,
review_count: int = 0,
@@ -107,6 +108,7 @@ class PromptRepository:
api_client_id=self.api_client.id,
payload_type=payload_type,
payload=PayloadContainer(payload=payload),
lang=lang,
depth=depth,
review_count=review_count,
review_result=review_result,
@@ -146,6 +148,7 @@ class PromptRepository:
def store_text_reply(
self,
text: str,
lang: str,
frontend_message_id: str,
user_frontend_message_id: str,
review_count: int = 0,
@@ -209,6 +212,7 @@ class PromptRepository:
task_id=task.id,
role=role,
payload=db_payload.MessagePayload(text=text),
lang=lang or "en",
depth=depth,
review_count=review_count,
review_result=review_result,
+43 -21
View File
@@ -190,14 +190,18 @@ class TreeManager:
return task_count_by_type
def determine_task_availability(self) -> dict[protocol_schema.TaskRequestType, int]:
def determine_task_availability(self, lang: str) -> dict[protocol_schema.TaskRequestType, int]:
self.pr.ensure_user_is_enabled()
num_active_trees = self.query_num_active_trees()
extendible_parents = self.query_extendible_parents()
prompts_need_review = self.query_prompts_need_review()
replies_need_review = self.query_replies_need_review()
incomplete_rankings = self.query_incomplete_rankings()
if not lang:
lang = "en"
logger.warning("Task availability request without lang tag received, assuming lang='en'.")
num_active_trees = self.query_num_active_trees(lang=lang)
extendible_parents = self.query_extendible_parents(lang=lang)
prompts_need_review = self.query_prompts_need_review(lang=lang)
replies_need_review = self.query_replies_need_review(lang=lang)
incomplete_rankings = self.query_incomplete_rankings(lang=lang)
return self._determine_task_availability_internal(
num_active_trees=num_active_trees,
@@ -208,23 +212,29 @@ class TreeManager:
)
def next_task(
self, desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random
self,
desired_task_type: protocol_schema.TaskRequestType = protocol_schema.TaskRequestType.random,
lang: str = "en",
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
logger.debug("TreeManager.next_task()")
logger.debug(f"TreeManager.next_task({desired_task_type=}, {lang=})")
self.pr.ensure_user_is_enabled()
num_active_trees = self.query_num_active_trees()
prompts_need_review = self.query_prompts_need_review()
replies_need_review = self.query_replies_need_review()
extendible_parents = self.query_extendible_parents()
if not lang:
lang = "en"
logger.warning("Task request without lang tag received, assuming 'en'.")
incomplete_rankings = self.query_incomplete_rankings()
num_active_trees = self.query_num_active_trees(lang=lang)
prompts_need_review = self.query_prompts_need_review(lang=lang)
replies_need_review = self.query_replies_need_review(lang=lang)
extendible_parents = self.query_extendible_parents(lang=lang)
incomplete_rankings = self.query_incomplete_rankings(lang=lang)
if not self.cfg.rank_prompter_replies:
incomplete_rankings = list(filter(lambda r: r.role == "assistant", incomplete_rankings))
active_tree_sizes = self.query_extendible_trees()
active_tree_sizes = self.query_extendible_trees(lang=lang)
# determine type of task to generate
num_missing_replies = sum(x.remaining_messages for x in active_tree_sizes)
@@ -458,6 +468,7 @@ class TreeManager:
# here we store the text reply in the database
message = pr.store_text_reply(
text=interaction.text,
lang=interaction.lang,
frontend_message_id=interaction.message_id,
user_frontend_message_id=interaction.user_message_id,
)
@@ -665,7 +676,7 @@ class TreeManager:
# calculate acceptance based on spam label
return np.mean([1 - l.labels[protocol_schema.TextLabel.spam] for l in labels])
def query_prompts_need_review(self) -> list[Message]:
def query_prompts_need_review(self, lang: str) -> list[Message]:
"""
Select initial prompt messages with fewer than the required number of reviews in active message tree
(active == True in message_tree_state)
@@ -682,6 +693,7 @@ class TreeManager:
not_(Message.deleted),
Message.review_count < self.cfg.num_reviews_initial_prompt,
Message.parent_id.is_(None),
Message.lang == lang,
)
)
@@ -690,7 +702,7 @@ class TreeManager:
return qry.all()
def query_replies_need_review(self) -> list[Message]:
def query_replies_need_review(self, lang: str) -> list[Message]:
"""
Select child messages (parent_id IS NOT NULL) with fewer than the required number of reviews
in active message tree (active == True in message_tree_state)
@@ -707,6 +719,7 @@ class TreeManager:
not_(Message.deleted),
Message.review_count < self.cfg.num_reviews_reply,
Message.parent_id.is_not(None),
Message.lang == lang,
)
)
@@ -724,13 +737,14 @@ FROM message_tree_state mts
WHERE mts.active -- only consider active trees
AND mts.state = :ranking_state -- message tree must be in ranking state
AND m.review_result -- must be reviewed
AND m.lang = :lang -- matches lang
AND NOT m.deleted -- not deleted
AND m.parent_id IS NOT NULL -- ignore initial prompts
GROUP BY m.parent_id, m.role
HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
"""
def query_incomplete_rankings(self) -> list[IncompleteRankingsRow]:
def query_incomplete_rankings(self, lang: str) -> list[IncompleteRankingsRow]:
"""Query parents which have childern that need further rankings"""
r = self.db.execute(
@@ -738,6 +752,7 @@ HAVING COUNT(m.id) > 1 and MIN(m.ranking_count) < :num_required_rankings
{
"num_required_rankings": self.cfg.num_required_rankings,
"ranking_state": message_tree_state.State.RANKING,
"lang": lang,
},
)
return [IncompleteRankingsRow.from_orm(x) for x in r.all()]
@@ -753,13 +768,14 @@ WHERE mts.active -- only consider active trees
AND NOT m.deleted -- ignore deleted messages as parents
AND m.depth < mts.max_depth -- ignore leaf nodes as parents
AND m.review_result -- parent node must have positive review
AND m.lang = :lang -- parent matches lang
AND NOT coalesce(c.deleted, FALSE) -- don't count deleted children
AND (c.review_result OR coalesce(c.review_count, 0) < :num_reviews_reply) -- don't count children with negative review but count elements under review
GROUP BY m.id, m.role, m.depth, m.message_tree_id, mts.max_children_count
HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
"""
def query_extendible_parents(self) -> list[ExtendibleParentRow]:
def query_extendible_parents(self, lang: str) -> list[ExtendibleParentRow]:
"""Query parent messages that have not reached the maximum number of replies."""
r = self.db.execute(
@@ -767,6 +783,7 @@ HAVING COUNT(c.id) < mts.max_children_count -- below maximum number of children
{
"growing_state": message_tree_state.State.GROWING,
"num_reviews_reply": self.cfg.num_reviews_reply,
"lang": lang,
},
)
return [ExtendibleParentRow.from_orm(x) for x in r.all()]
@@ -787,7 +804,7 @@ GROUP BY m.message_tree_id, mts.goal_tree_size
HAVING COUNT(m.id) < mts.goal_tree_size
"""
def query_extendible_trees(self) -> list[ActiveTreeSizeRow]:
def query_extendible_trees(self, lang: str) -> list[ActiveTreeSizeRow]:
"""Query size of active message trees in growing state."""
r = self.db.execute(
@@ -795,6 +812,7 @@ HAVING COUNT(m.id) < mts.goal_tree_size
{
"growing_state": message_tree_state.State.GROWING,
"num_reviews_reply": self.cfg.num_reviews_reply,
"lang": lang,
},
)
return [ActiveTreeSizeRow.from_orm(x) for x in r.all()]
@@ -894,8 +912,12 @@ INNER JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Ranki
logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=:s})")
self._insert_default_state(id, state=state)
def query_num_active_trees(self, lang: str) -> int:
    """Count active message trees for one language.

    A tree is counted when its ``MessageTreeState`` row is active and the
    message whose id equals the tree id (presumably the tree's initial prompt;
    confirm against the data model) has ``lang`` equal to the given tag.

    :param lang: language tag to match against ``Message.lang``.
    :return: number of matching trees, as returned by the COUNT query.
    """
    query = (
        self.db.query(func.count(MessageTreeState.message_tree_id))
        .join(Message, MessageTreeState.message_tree_id == Message.id)
        .filter(MessageTreeState.active, Message.lang == lang)
    )
    return query.scalar()
def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]:
@@ -43,6 +43,7 @@ class ConversationMessage(BaseModel):
id: Optional[UUID] = None
frontend_message_id: Optional[str] = None
text: str
lang: Optional[str] # BCP 47
is_assistant: bool
@@ -72,6 +73,7 @@ class TaskRequest(BaseModel):
# this is optional. https://github.com/pydantic/pydantic/issues/1270
user: Optional[User] = Field(None, nullable=True)
collective: bool = False
lang: Optional[str] = Field(None, nullable=True) # BCP 47
class TaskAck(BaseModel):
@@ -266,6 +268,7 @@ class TextReplyToMessage(Interaction):
message_id: str
user_message_id: str
text: constr(min_length=1, strip_whitespace=True)
lang: Optional[str] # BCP 47
class MessageRating(Interaction):
@@ -73,6 +73,7 @@ async def test_can_post_interaction(oasst_api_client_mocked: OasstApiClient):
message_id="123",
user_message_id="321",
text="This is my reply",
lang="en",
user=protocol_schema.User(
id="123",
display_name="lomz",