mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
70 lines
2.7 KiB
Python
70 lines
2.7 KiB
Python
from enum import Enum
|
|
from uuid import UUID
|
|
|
|
import sqlalchemy as sa
|
|
import sqlalchemy.dialects.postgresql as pg
|
|
from sqlmodel import Field, SQLModel
|
|
|
|
|
|
class State(str, Enum):
    """States of the Open-Assistant message tree state machine.

    Because ``State`` mixes in ``str``, every member is itself a string and
    compares equal to its value, e.g. ``State.GROWING == "growing"``.
    """

    #: The message tree consists only of a single initial prompt root node.
    #: Initial prompt labeling tasks determine whether the tree moves on to the
    #: `growing` state or to the `aborted_low_grade` state.
    INITIAL_PROMPT_REVIEW = "initial_prompt_review"

    #: Assistant and prompter human demonstrations are collected. Concurrently,
    #: labeling tasks are handed out to check whether the quality of the replies
    #: surpasses the minimum acceptable quality.
    #:
    #: Once the required number of messages passing the initial labeling-quality
    #: check has been collected, the tree enters `ranking`. If too many
    #: poor-quality labeling responses are received, the tree can instead enter
    #: the `aborted_low_grade` state.
    GROWING = "growing"

    #: The tree has been successfully populated with the desired number of
    #: messages. Ranking tasks are now handed out for all nodes with more than
    #: one child.
    RANKING = "ranking"

    #: The required ranking responses have been collected, and the scoring
    #: algorithm can now compute the aggregated ranking scores that will appear
    #: in the dataset.
    READY_FOR_SCORING = "ready_for_scoring"

    #: The scoring algorithm has computed ranking scores for all children. The
    #: message tree can be exported as part of an Open-Assistant message tree
    #: dataset.
    READY_FOR_EXPORT = "ready_for_export"

    #: An exception occurred in the scoring algorithm.
    SCORING_FAILED = "scoring_failed"

    #: The system received too many bad reviews and stopped handing out tasks
    #: for this message tree.
    ABORTED_LOW_GRADE = "aborted_low_grade"

    #: A moderator decided to manually halt the message tree construction
    #: process.
    HALTED_BY_MODERATOR = "halted_by_moderator"
|
|
|
|
|
|
# Tuple of the `State` members grouped as "valid"; six of the eight members
# are listed.
# NOTE(review): State.SCORING_FAILED and State.HALTED_BY_MODERATOR are
# deliberately(?) absent here even though both appear in TERMINAL_STATES —
# confirm the intended meaning of "valid" before treating this as "all states".
VALID_STATES = (
    # members that are not part of TERMINAL_STATES
    State.INITIAL_PROMPT_REVIEW,
    State.GROWING,
    State.RANKING,
    State.READY_FOR_SCORING,
    # members that are also part of TERMINAL_STATES
    State.READY_FOR_EXPORT,
    State.ABORTED_LOW_GRADE,
)
|
|
|
|
# `State` members grouped as terminal. Judging by the name, no further
# transitions are expected out of these states — TODO confirm against callers.
TERMINAL_STATES = (
    State.READY_FOR_EXPORT,
    State.ABORTED_LOW_GRADE,
    State.SCORING_FAILED,
    State.HALTED_BY_MODERATOR,
)
|
|
|
|
|
|
class MessageTreeState(SQLModel, table=True):
    """Table model for the ``message_tree_state`` table.

    Columns:

    - ``message_tree_id``: primary key, stored as a PostgreSQL UUID and declared
      as a foreign key to ``message.id`` (presumably the tree's root message —
      TODO confirm).
    - ``goal_tree_size``, ``max_depth``, ``max_children_count``: non-nullable
      integers.
    - ``state``: non-nullable, indexed string of at most 128 characters;
      presumably holds a ``State`` value — verify against callers.
    - ``active``: non-nullable, indexed boolean.
    """

    __tablename__ = "message_tree_state"

    # An explicit sa_column is used so the column can carry the PostgreSQL UUID
    # type, the foreign key and the primary-key flag together.
    message_tree_id: UUID = Field(
        sa_column=sa.Column(
            pg.UUID(as_uuid=True),
            sa.ForeignKey("message.id"),
            primary_key=True,
        )
    )
    goal_tree_size: int = Field(nullable=False)
    max_depth: int = Field(nullable=False)
    max_children_count: int = Field(nullable=False)
    state: str = Field(max_length=128, index=True, nullable=False)
    active: bool = Field(index=True, nullable=False)