fix breaking api changes

This commit is contained in:
Alex Ott
2022-12-30 22:52:57 -08:00
37 changed files with 2671 additions and 16048 deletions
+8 -2
View File
@@ -35,6 +35,8 @@ We are not going to stop at replicating ChatGPT. We want to build the assistant
### Slide Decks
[Vision & Roadmap](https://docs.google.com/presentation/d/1n7IrAOVOqwdYgiYrXc8Sj0He8krn5MVZO_iLkCjTtu0/edit?usp=sharing)
[Important Data Structures](https://docs.google.com/presentation/d/1iaX_nxasVWlvPiSNs0cllR9L_1neZq0RJxd6MFEalUY/edit?usp=sharing)
## How can you help?
@@ -43,9 +45,11 @@ All open source projects begins with people like you. Open source is the belief
## I'm in! Now what?
[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e)
[Join the OpenAssistant Contributors Discord Server!](https://ykilcher.com/open-assistant-discord), this is for work coordination.
[and / or the YK Discord Server](https://ykilcher.com/discord)
[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e), it has a dedicated channel and is more public.
[and / or the YK Discord Server](https://ykilcher.com/discord), which also has a dedicated, though less active, channel.
[Visit the Notion](https://ykilcher.com/open-assistant)
@@ -61,6 +65,8 @@ has assigned the issue to you, start working on it.
If the issue is currently unclear but you are interested, please post in
Discord and someone can help clarify the issue with more detail.
**Always Welcome:** Documentation markdowns in `docs/`, docstrings, diagrams of the system architecture, and other documentation.
### Submitting Work
We're all working on different parts of Open Assistant together. To make
@@ -0,0 +1,339 @@
# -*- coding: utf-8 -*-
"""name changes: person->user, post->message, work_package->task
Revision ID: abb47e9d145a
Revises: 73ce3675c1f5
Create Date: 2022-12-30 20:54:49.880568
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "abb47e9d145a"
down_revision = "73ce3675c1f5"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Replace the person/post/work_package schema with user/message/task.

    This migration is destructive: every affected table is emptied first, the
    new ``user``, ``message``, ``task``, ``user_stats`` and ``message_reaction``
    tables are created, the old ``person``, ``post``, ``work_package``,
    ``person_stats`` and ``post_reaction`` tables are dropped, and the
    ``journal`` / ``text_labels`` reference columns are swapped to the new names.
    """
    # clear DB
    # Rows are removed children-first so no DELETE trips a foreign key:
    # journal and text_labels reference post, post_reaction references
    # work_package, and person_stats references person.
    op.execute("DELETE FROM journal;")
    op.execute("DELETE FROM text_labels;")
    op.execute("DELETE FROM post_reaction;")
    op.execute("DELETE FROM work_package;")
    op.execute("DELETE FROM post;")
    op.execute("DELETE FROM person_stats;")
    op.execute("DELETE FROM person;")

    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "user",
        sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
        sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.Column("username", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
        sa.Column("auth_method", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
        sa.Column("display_name", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False),
        sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
        sa.ForeignKeyConstraint(
            ["api_client_id"],
            ["api_client.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index("ix_user_username", "user", ["api_client_id", "username", "auth_method"], unique=True)
    op.create_table(
        "message",
        sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
        sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column("depth", sa.Integer(), server_default=sa.text("0"), nullable=False),
        sa.Column("children_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
        sa.Column("parent_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
        sa.Column("message_tree_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
        sa.Column("task_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
        sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
        sa.Column("role", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
        sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
        sa.Column("frontend_message_id", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
        sa.Column("payload_type", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
        sa.Column("lang", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
        sa.ForeignKeyConstraint(
            ["api_client_id"],
            ["api_client.id"],
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index("ix_message_frontend_message_id", "message", ["api_client_id", "frontend_message_id"], unique=True)
    op.create_index(op.f("ix_message_message_tree_id"), "message", ["message_tree_id"], unique=False)
    op.create_index(op.f("ix_message_task_id"), "message", ["task_id"], unique=False)
    op.create_index(op.f("ix_message_user_id"), "message", ["user_id"], unique=False)
    op.create_table(
        "task",
        sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
        sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.Column("expiry_date", sa.DateTime(), nullable=True),
        sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
        sa.Column("done", sa.Boolean(), server_default=sa.text("false"), nullable=False),
        sa.Column("collective", sa.Boolean(), server_default=sa.text("false"), nullable=False),
        sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
        sa.Column("payload_type", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
        sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
        sa.Column("ack", sa.Boolean(), nullable=True),
        sa.Column("frontend_message_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
        sa.Column("message_tree_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
        sa.Column("parent_message_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
        sa.ForeignKeyConstraint(
            ["api_client_id"],
            ["api_client.id"],
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(op.f("ix_task_user_id"), "task", ["user_id"], unique=False)
    op.create_table(
        "user_stats",
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("modified_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.Column("leader_score", sa.Integer(), nullable=False),
        sa.Column("reactions", sa.Integer(), nullable=False),
        sa.Column("messages", sa.Integer(), nullable=False),
        sa.Column("upvotes", sa.Integer(), nullable=False),
        sa.Column("downvotes", sa.Integer(), nullable=False),
        sa.Column("task_reward", sa.Integer(), nullable=False),
        sa.Column("compare_wins", sa.Integer(), nullable=False),
        sa.Column("compare_losses", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("user_id"),
    )
    op.create_table(
        "message_reaction",
        sa.Column("task_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
        sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
        sa.Column("payload_type", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
        sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
        sa.ForeignKeyConstraint(
            ["api_client_id"],
            ["api_client.id"],
        ),
        sa.ForeignKeyConstraint(
            ["task_id"],
            ["task.id"],
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("task_id", "user_id"),
    )
    # Detach journal/text_labels from the old tables before those are dropped.
    op.drop_constraint("text_labels_post_id_fkey", "text_labels", type_="foreignkey")
    op.drop_constraint("journal_post_id_fkey", "journal", type_="foreignkey")
    op.drop_constraint("journal_person_id_fkey", "journal", type_="foreignkey")
    op.drop_table("post_reaction")
    op.drop_index("ix_post_frontend_post_id", table_name="post")
    op.drop_index("ix_post_person_id", table_name="post")
    op.drop_index("ix_post_thread_id", table_name="post")
    op.drop_index("ix_post_workpackage_id", table_name="post")
    op.drop_table("post")
    op.drop_index("ix_work_package_person_id", table_name="work_package")
    op.drop_table("work_package")
    op.drop_table("person_stats")
    op.drop_index("ix_person_username", table_name="person")
    op.drop_table("person")
    op.add_column("journal", sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True))
    op.add_column("journal", sa.Column("message_id", sqlmodel.sql.sqltypes.GUID(), nullable=True))
    op.drop_index("ix_journal_person_id", table_name="journal")
    op.create_index(op.f("ix_journal_user_id"), "journal", ["user_id"], unique=False)
    # The new foreign keys are named explicitly because downgrade() drops them
    # by exactly these names.
    op.create_foreign_key(op.f("journal_user_id_fkey"), "journal", "user", ["user_id"], ["id"])
    op.create_foreign_key(op.f("journal_message_id_fkey"), "journal", "message", ["message_id"], ["id"])
    op.drop_column("journal", "person_id")
    op.drop_column("journal", "post_id")
    op.add_column("text_labels", sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=True))
    op.create_foreign_key(op.f("text_labels_message_id_fkey"), "text_labels", "message", ["message_id"], ["id"])
    op.drop_column("text_labels", "post_id")
    # ### end Alembic commands ###
def downgrade() -> None:
    """Restore the person/post/work_package schema from user/message/task.

    This migration is destructive: every affected table is emptied first, the
    ``journal`` / ``text_labels`` reference columns are switched back to the old
    names, the ``person``, ``person_stats``, ``work_package``, ``post`` and
    ``post_reaction`` tables are recreated, and the ``message_reaction``,
    ``user_stats``, ``task``, ``message`` and ``user`` tables are dropped.
    """
    # clear DB
    # Rows are removed children-first so no DELETE trips a foreign key:
    # journal and text_labels reference message, message_reaction references
    # task, and user_stats/task/message reference "user".
    op.execute("DELETE FROM journal;")
    op.execute("DELETE FROM text_labels;")
    op.execute("DELETE FROM message_reaction;")
    op.execute("DELETE FROM task;")
    op.execute("DELETE FROM message;")
    op.execute("DELETE FROM user_stats;")
    # "user" is quoted in the raw SQL, unlike the other table names.
    op.execute('DELETE FROM "user";')

    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("text_labels", sa.Column("post_id", postgresql.UUID(), autoincrement=False, nullable=True))
    op.drop_constraint("text_labels_message_id_fkey", "text_labels", type_="foreignkey")
    op.drop_column("text_labels", "message_id")
    op.add_column("journal", sa.Column("post_id", postgresql.UUID(), autoincrement=False, nullable=True))
    op.add_column("journal", sa.Column("person_id", postgresql.UUID(), autoincrement=False, nullable=True))
    op.drop_constraint("journal_message_id_fkey", "journal", type_="foreignkey")
    op.drop_constraint("journal_user_id_fkey", "journal", type_="foreignkey")
    op.drop_index(op.f("ix_journal_user_id"), table_name="journal")
    op.create_index("ix_journal_person_id", "journal", ["person_id"], unique=False)
    op.drop_column("journal", "message_id")
    op.drop_column("journal", "user_id")
    op.create_table(
        "person",
        sa.Column(
            "id", postgresql.UUID(), server_default=sa.text("gen_random_uuid()"), autoincrement=False, nullable=False
        ),
        sa.Column("username", sa.VARCHAR(length=128), autoincrement=False, nullable=False),
        sa.Column("display_name", sa.VARCHAR(length=256), autoincrement=False, nullable=False),
        sa.Column(
            "created_date",
            postgresql.TIMESTAMP(),
            server_default=sa.text("CURRENT_TIMESTAMP"),
            autoincrement=False,
            nullable=False,
        ),
        sa.Column("api_client_id", postgresql.UUID(), autoincrement=False, nullable=False),
        sa.Column("auth_method", sa.VARCHAR(length=128), autoincrement=False, nullable=False),
        sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"], name="person_api_client_id_fkey"),
        sa.PrimaryKeyConstraint("id", name="person_pkey"),
    )
    op.create_table(
        "person_stats",
        sa.Column("person_id", postgresql.UUID(), autoincrement=False, nullable=False),
        sa.Column("leader_score", sa.INTEGER(), autoincrement=False, nullable=False),
        sa.Column(
            "modified_date",
            postgresql.TIMESTAMP(),
            server_default=sa.text("CURRENT_TIMESTAMP"),
            autoincrement=False,
            nullable=False,
        ),
        sa.Column("reactions", sa.INTEGER(), autoincrement=False, nullable=False),
        sa.Column("posts", sa.INTEGER(), autoincrement=False, nullable=False),
        sa.Column("upvotes", sa.INTEGER(), autoincrement=False, nullable=False),
        sa.Column("downvotes", sa.INTEGER(), autoincrement=False, nullable=False),
        sa.Column("work_reward", sa.INTEGER(), autoincrement=False, nullable=False),
        sa.Column("compare_wins", sa.INTEGER(), autoincrement=False, nullable=False),
        sa.Column("compare_losses", sa.INTEGER(), autoincrement=False, nullable=False),
        sa.ForeignKeyConstraint(["person_id"], ["person.id"], name="person_stats_person_id_fkey"),
        sa.PrimaryKeyConstraint("person_id", name="person_stats_pkey"),
    )
    op.create_table(
        "work_package",
        sa.Column(
            "id", postgresql.UUID(), server_default=sa.text("gen_random_uuid()"), autoincrement=False, nullable=False
        ),
        sa.Column(
            "created_date",
            postgresql.TIMESTAMP(),
            server_default=sa.text("CURRENT_TIMESTAMP"),
            autoincrement=False,
            nullable=False,
        ),
        sa.Column("expiry_date", postgresql.TIMESTAMP(), autoincrement=False, nullable=True),
        sa.Column("person_id", postgresql.UUID(), autoincrement=False, nullable=True),
        sa.Column("payload_type", sa.VARCHAR(length=200), autoincrement=False, nullable=False),
        sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=False),
        sa.Column("api_client_id", postgresql.UUID(), autoincrement=False, nullable=False),
        sa.Column("done", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False),
        sa.Column("ack", sa.BOOLEAN(), autoincrement=False, nullable=True),
        sa.Column("frontend_ref_post_id", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("thread_id", postgresql.UUID(), autoincrement=False, nullable=True),
        sa.Column("parent_post_id", postgresql.UUID(), autoincrement=False, nullable=True),
        sa.Column("collective", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False),
        sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"], name="work_package_api_client_id_fkey"),
        sa.ForeignKeyConstraint(["person_id"], ["person.id"], name="work_package_person_id_fkey"),
        sa.PrimaryKeyConstraint("id", name="work_package_pkey"),
    )
    op.create_index("ix_work_package_person_id", "work_package", ["person_id"], unique=False)
    op.create_table(
        "post",
        sa.Column(
            "id", postgresql.UUID(), server_default=sa.text("gen_random_uuid()"), autoincrement=False, nullable=False
        ),
        sa.Column("parent_id", postgresql.UUID(), autoincrement=False, nullable=True),
        sa.Column("thread_id", postgresql.UUID(), autoincrement=False, nullable=False),
        sa.Column("workpackage_id", postgresql.UUID(), autoincrement=False, nullable=True),
        sa.Column("person_id", postgresql.UUID(), autoincrement=False, nullable=True),
        sa.Column("api_client_id", postgresql.UUID(), autoincrement=False, nullable=False),
        sa.Column("role", sa.VARCHAR(length=128), autoincrement=False, nullable=False),
        sa.Column("frontend_post_id", sa.VARCHAR(length=200), autoincrement=False, nullable=False),
        sa.Column(
            "created_date",
            postgresql.TIMESTAMP(),
            server_default=sa.text("CURRENT_TIMESTAMP"),
            autoincrement=False,
            nullable=False,
        ),
        sa.Column("payload_type", sa.VARCHAR(length=200), autoincrement=False, nullable=False),
        sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=True),
        sa.Column("depth", sa.INTEGER(), server_default=sa.text("0"), autoincrement=False, nullable=False),
        sa.Column("children_count", sa.INTEGER(), server_default=sa.text("0"), autoincrement=False, nullable=False),
        sa.Column("lang", sa.VARCHAR(length=200), autoincrement=False, nullable=False),
        sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"], name="post_api_client_id_fkey"),
        sa.ForeignKeyConstraint(["person_id"], ["person.id"], name="post_person_id_fkey"),
        sa.PrimaryKeyConstraint("id", name="post_pkey"),
    )
    op.create_index("ix_post_workpackage_id", "post", ["workpackage_id"], unique=False)
    op.create_index("ix_post_thread_id", "post", ["thread_id"], unique=False)
    op.create_index("ix_post_person_id", "post", ["person_id"], unique=False)
    op.create_index("ix_post_frontend_post_id", "post", ["api_client_id", "frontend_post_id"], unique=False)
    op.create_table(
        "post_reaction",
        sa.Column("person_id", postgresql.UUID(), autoincrement=False, nullable=False),
        sa.Column(
            "created_date",
            postgresql.TIMESTAMP(),
            server_default=sa.text("CURRENT_TIMESTAMP"),
            autoincrement=False,
            nullable=False,
        ),
        sa.Column("payload_type", sa.VARCHAR(length=200), autoincrement=False, nullable=False),
        sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=False),
        sa.Column("api_client_id", postgresql.UUID(), autoincrement=False, nullable=False),
        sa.Column("work_package_id", postgresql.UUID(), autoincrement=False, nullable=False),
        sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"], name="post_reaction_api_client_id_fkey"),
        sa.ForeignKeyConstraint(["person_id"], ["person.id"], name="post_reaction_person_id_fkey"),
        sa.ForeignKeyConstraint(["work_package_id"], ["work_package.id"], name="post_reaction_work_package_id_fkey"),
    )
    op.create_index("ix_person_username", "person", ["api_client_id", "username", "auth_method"], unique=False)
    # Re-attach journal/text_labels now that the old tables exist again.
    op.create_foreign_key("text_labels_post_id_fkey", "text_labels", "post", ["post_id"], ["id"])
    op.create_foreign_key("journal_person_id_fkey", "journal", "person", ["person_id"], ["id"])
    op.create_foreign_key("journal_post_id_fkey", "journal", "post", ["post_id"], ["id"])
    # Drop the new tables dependants-first: message_reaction references task and
    # user, user_stats/task/message reference user.
    op.drop_table("message_reaction")
    op.drop_table("user_stats")
    op.drop_index(op.f("ix_task_user_id"), table_name="task")
    op.drop_table("task")
    op.drop_index(op.f("ix_message_user_id"), table_name="message")
    op.drop_index(op.f("ix_message_task_id"), table_name="message")
    op.drop_index(op.f("ix_message_message_tree_id"), table_name="message")
    op.drop_index("ix_message_frontend_message_id", table_name="message")
    op.drop_table("message")
    op.drop_index("ix_user_username", table_name="user")
    op.drop_table("user")
    # ### end Alembic commands ###
+61 -60
View File
@@ -67,10 +67,10 @@ if settings.DEBUG_USE_SEED_DATA:
@app.on_event("startup")
def seed_data():
class DummyPost(pydantic.BaseModel):
task_post_id: str
user_post_id: str
parent_post_id: Optional[str]
class DummyMessage(pydantic.BaseModel):
task_message_id: str
user_message_id: str
parent_message_id: Optional[str]
text: str
role: str
@@ -81,96 +81,97 @@ if settings.DEBUG_USE_SEED_DATA:
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
pr = PromptRepository(db=db, api_client=api_client, user=dummy_user)
dummy_posts = [
DummyPost(
task_post_id="de111fa8",
user_post_id="6f1d0711",
parent_post_id=None,
dummy_messages = [
DummyMessage(
task_message_id="de111fa8",
user_message_id="6f1d0711",
parent_message_id=None,
text="Hi!",
role="user",
role="prompter",
),
DummyPost(
task_post_id="74c381d4",
user_post_id="4a24530b",
parent_post_id="6f1d0711",
DummyMessage(
task_message_id="74c381d4",
user_message_id="4a24530b",
parent_message_id="6f1d0711",
text="Hello! How can I help you?",
role="assistant",
),
DummyPost(
task_post_id="3d5dc440",
user_post_id="a8c01c04",
parent_post_id="4a24530b",
DummyMessage(
task_message_id="3d5dc440",
user_message_id="a8c01c04",
parent_message_id="4a24530b",
text="Do you have a recipe for potato soup?",
role="user",
role="prompter",
),
DummyPost(
task_post_id="643716c1",
user_post_id="f43a93b7",
parent_post_id="4a24530b",
DummyMessage(
task_message_id="643716c1",
user_message_id="f43a93b7",
parent_message_id="4a24530b",
text="Who were the 8 presidents before George Washington?",
role="user",
role="prompter",
),
DummyPost(
task_post_id="2e4e1e6",
user_post_id="c886920",
parent_post_id="6f1d0711",
DummyMessage(
task_message_id="2e4e1e6",
user_message_id="c886920",
parent_message_id="6f1d0711",
text="Hey buddy! How can I serve you?",
role="assistant",
),
DummyPost(
task_post_id="970c437d",
user_post_id="cec432cf",
parent_post_id=None,
DummyMessage(
task_message_id="970c437d",
user_message_id="cec432cf",
parent_message_id=None,
text="euirdteunvglfe23908230892309832098 AAAAAAAA",
role="user",
role="prompter",
),
DummyPost(
task_post_id="6066118e",
user_post_id="4f85f637",
parent_post_id="cec432cf",
DummyMessage(
task_message_id="6066118e",
user_message_id="4f85f637",
parent_message_id="cec432cf",
text="Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?",
role="assistant",
),
DummyPost(
task_post_id="ba87780d",
user_post_id="0e276b98",
parent_post_id="cec432cf",
DummyMessage(
task_message_id="ba87780d",
user_message_id="0e276b98",
parent_message_id="cec432cf",
text="I'm unsure how to interpret this. Is it a riddle?",
role="assistant",
),
]
for p in dummy_posts:
wp = pr.fetch_workpackage_by_postid(p.task_post_id)
if wp and not wp.ack:
logger.warning("Deleting unacknowledged seed data work package")
db.delete(wp)
wp = None
if not wp:
if p.parent_post_id is None:
wp = pr.store_task(
protocol_schema.InitialPromptTask(hint=""), thread_id=None, parent_post_id=None
for msg in dummy_messages:
task = pr.fetch_task_by_frontend_message_id(msg.task_message_id)
if task and not task.ack:
logger.warning("Deleting unacknowledged seed data task")
db.delete(task)
task = None
if not task:
if msg.parent_message_id is None:
task = pr.store_task(
protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
)
else:
print("p.parent_post_id", p.parent_post_id)
parent_post = pr.fetch_post_by_frontend_post_id(p.parent_post_id, fail_if_missing=True)
wp = pr.store_task(
parent_message = pr.fetch_message_by_frontend_message_id(
msg.parent_message_id, fail_if_missing=True
)
task = pr.store_task(
protocol_schema.AssistantReplyTask(
conversation=protocol_schema.Conversation(
messages=[protocol_schema.ConversationMessage(text="dummy", is_assistant=False)]
)
),
thread_id=parent_post.thread_id,
parent_post_id=parent_post.id,
message_tree_id=parent_message.message_tree_id,
parent_message_id=parent_message.id,
)
pr.bind_frontend_post_id(wp.id, p.task_post_id)
post = pr.store_text_reply(p.text, p.task_post_id, p.user_post_id)
pr.bind_frontend_message_id(task.id, msg.task_message_id)
message = pr.store_text_reply(msg.text, msg.task_message_id, msg.user_message_id)
logger.info(
f"Inserted: post_id: {post.id}, payload: {post.payload.payload}, parent_post_id: {post.parent_id}"
f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}"
)
else:
logger.debug(f"seed data work_package found: {wp.id}")
logger.debug(f"seed data task found: {task.id}")
logger.info("Seed data check completed")
except Exception:
+51 -45
View File
@@ -18,8 +18,8 @@ router = APIRouter()
def generate_task(
request: protocol_schema.TaskRequest, pr: PromptRepository
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
thread_id = None
parent_post_id = None
message_tree_id = None
parent_message_id = None
match request.type:
case protocol_schema.TaskRequestType.random:
@@ -54,38 +54,42 @@ def generate_task(
task = protocol_schema.InitialPromptTask(
hint="Ask the assistant about a current event." # this is optional
)
case protocol_schema.TaskRequestType.user_reply:
logger.info("Generating a UserReplyTask.")
posts = pr.fetch_random_conversation("assistant")
messages = [
protocol_schema.ConversationMessage(text=p.payload.payload.text, is_assistant=(p.role == "assistant"))
for p in posts
case protocol_schema.TaskRequestType.prompter_reply:
logger.info("Generating a PrompterReplyTask.")
messages = pr.fetch_random_conversation("assistant")
task_messages = [
protocol_schema.ConversationMessage(
text=msg.payload.payload.text, is_assistant=(msg.role == "assistant")
)
for msg in messages
]
task = protocol_schema.UserReplyTask(conversation=protocol_schema.Conversation(messages=messages))
thread_id = posts[-1].thread_id
parent_post_id = posts[-1].id
task = protocol_schema.PrompterReplyTask(conversation=protocol_schema.Conversation(messages=task_messages))
message_tree_id = messages[-1].message_tree_id
parent_message_id = messages[-1].id
case protocol_schema.TaskRequestType.assistant_reply:
logger.info("Generating a AssistantReplyTask.")
posts = pr.fetch_random_conversation("user")
messages = [
protocol_schema.ConversationMessage(text=p.payload.payload.text, is_assistant=(p.role == "assistant"))
for p in posts
messages = pr.fetch_random_conversation("prompter")
task_messages = [
protocol_schema.ConversationMessage(
text=msg.payload.payload.text, is_assistant=(msg.role == "assistant")
)
for msg in messages
]
task = protocol_schema.AssistantReplyTask(conversation=protocol_schema.Conversation(messages=messages))
thread_id = posts[-1].thread_id
parent_post_id = posts[-1].id
task = protocol_schema.AssistantReplyTask(conversation=protocol_schema.Conversation(messages=task_messages))
message_tree_id = messages[-1].message_tree_id
parent_message_id = messages[-1].id
case protocol_schema.TaskRequestType.rank_initial_prompts:
logger.info("Generating a RankInitialPromptsTask.")
posts = pr.fetch_random_initial_prompts()
task = protocol_schema.RankInitialPromptsTask(prompts=[p.payload.payload.text for p in posts])
case protocol_schema.TaskRequestType.rank_user_replies:
logger.info("Generating a RankUserRepliesTask.")
conversation, replies = pr.fetch_multiple_random_replies(post_role="assistant")
messages = pr.fetch_random_initial_prompts()
task = protocol_schema.RankInitialPromptsTask(prompts=[msg.payload.payload.text for msg in messages])
case protocol_schema.TaskRequestType.rank_prompter_replies:
logger.info("Generating a RankPrompterRepliesTask.")
conversation, replies = pr.fetch_multiple_random_replies(message_role="assistant")
messages = [
task_messages = [
protocol_schema.ConversationMessage(
text=p.payload.payload.text,
is_assistant=(p.role == "assistant"),
@@ -93,18 +97,18 @@ def generate_task(
for p in conversation
]
replies = [p.payload.payload.text for p in replies]
task = protocol_schema.RankUserRepliesTask(
task = protocol_schema.RankPrompterRepliesTask(
conversation=protocol_schema.Conversation(
messages=messages,
messages=task_messages,
),
replies=replies,
)
case protocol_schema.TaskRequestType.rank_assistant_replies:
logger.info("Generating a RankAssistantRepliesTask.")
conversation, replies = pr.fetch_multiple_random_replies(post_role="user")
conversation, replies = pr.fetch_multiple_random_replies(message_role="prompter")
messages = [
task_messages = [
protocol_schema.ConversationMessage(
text=p.payload.payload.text,
is_assistant=(p.role == "assistant"),
@@ -113,7 +117,7 @@ def generate_task(
]
replies = [p.payload.payload.text for p in replies]
task = protocol_schema.RankAssistantRepliesTask(
conversation=protocol_schema.Conversation(messages=messages),
conversation=protocol_schema.Conversation(messages=task_messages),
replies=replies,
)
case _:
@@ -121,7 +125,7 @@ def generate_task(
logger.info(f"Generated {task=}.")
return task, thread_id, parent_post_id
return task, message_tree_id, parent_message_id
@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added
@@ -138,8 +142,8 @@ def request_task(
try:
pr = PromptRepository(db, api_client, request.user)
task, thread_id, parent_post_id = generate_task(request, pr)
pr.store_task(task, thread_id, parent_post_id, request.collective)
task, message_tree_id, parent_message_id = generate_task(request, pr)
pr.store_task(task, message_tree_id, parent_message_id, request.collective)
except OasstError:
raise
@@ -150,7 +154,7 @@ def request_task(
@router.post("/{task_id}/ack")
def acknowledge_task(
def tasks_acknowledge(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
@@ -166,9 +170,9 @@ def acknowledge_task(
try:
pr = PromptRepository(db, api_client, user=None)
# here we store the post id in the database for the task
# here we store the message id in the database for the task
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
pr.bind_frontend_post_id(task_id=task_id, post_id=ack_request.post_id)
pr.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)
except OasstError:
raise
@@ -179,7 +183,7 @@ def acknowledge_task(
@router.post("/{task_id}/nack")
def acknowledge_task_failure(
def tasks_acknowledge_failure(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
@@ -201,7 +205,7 @@ def acknowledge_task_failure(
@router.post("/interaction")
def post_interaction(
def tasks_interaction(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
@@ -216,29 +220,31 @@ def post_interaction(
pr = PromptRepository(db, api_client, user=interaction.user)
match type(interaction):
case protocol_schema.TextReplyToPost:
case protocol_schema.TextReplyToMessage:
logger.info(
f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}."
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
)
# here we store the text reply in the database
pr.store_text_reply(
text=interaction.text, post_id=interaction.post_id, user_post_id=interaction.user_post_id
text=interaction.text,
frontend_message_id=interaction.message_id,
user_frontend_message_id=interaction.user_message_id,
)
return protocol_schema.TaskDone()
case protocol_schema.PostRating:
case protocol_schema.MessageRating:
logger.info(
f"Frontend reports rating of {interaction.post_id=} with {interaction.rating=} by {interaction.user=}."
f"Frontend reports rating of {interaction.message_id=} with {interaction.rating=} by {interaction.user=}."
)
# here we store the rating in the database
pr.store_rating(interaction)
return protocol_schema.TaskDone()
case protocol_schema.PostRanking:
case protocol_schema.MessageRanking:
logger.info(
f"Frontend reports ranking of {interaction.post_id=} with {interaction.ranking=} by {interaction.user=}."
f"Frontend reports ranking of {interaction.message_id=} with {interaction.ranking=} by {interaction.user=}."
)
# TODO: check if the ranking is valid
@@ -262,5 +268,5 @@ def close_collective_task(
):
api_client = deps.api_auth(api_key, db)
pr = PromptRepository(db, api_client, user=None)
pr.close_task(close_task_request.post_id)
pr.close_task(close_task_request.message_id)
return protocol_schema.TaskDone()
+10 -10
View File
@@ -27,21 +27,21 @@ class OasstErrorCode(IntEnum):
TASK_GENERATION_FAILED = 1005
# 2000-3000: prompt_repository
INVALID_POST_ID = 2000
POST_NOT_FOUND = 2001
INVALID_FRONTEND_MESSAGE_ID = 2000
MESSAGE_NOT_FOUND = 2001
RATING_OUT_OF_RANGE = 2002
INVALID_RANKING_VALUE = 2003
INVALID_TASK_TYPE = 2004
USER_NOT_SPECIFIED = 2005
NO_THREADS_FOUND = 2006
NO_MESSAGE_TREE_FOUND = 2006
NO_REPLIES_FOUND = 2007
WORK_PACKAGE_NOT_FOUND = 2100
WORK_PACKAGE_EXPIRED = 2101
WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2102
WORK_PACKAGE_ALREADY_UPDATED = 2103
WORK_PACKAGE_NOT_ACK = 2104
WORK_PACKAGE_ALREADY_DONE = 2105
WORK_PACKAGE_NOT_COLLECTIVE = 2106
TASK_NOT_FOUND = 2100
TASK_EXPIRED = 2101
TASK_PAYLOAD_TYPE_MISMATCH = 2102
TASK_ALREADY_UPDATED = 2103
TASK_NOT_ACK = 2104
TASK_ALREADY_DONE = 2105
TASK_NOT_COLLECTIVE = 2106
class OasstError(Exception):
+38 -38
View File
@@ -3,7 +3,7 @@ import enum
from typing import Literal, Optional
from uuid import UUID
from oasst_backend.models import ApiClient, Journal, Person, WorkPackage
from oasst_backend.models import ApiClient, Journal, Task, User
from oasst_backend.models.payload_column_type import PayloadContainer, payload_type
from oasst_shared.utils import utcnow
from pydantic import BaseModel
@@ -14,71 +14,71 @@ class JournalEventType(str, enum.Enum):
"""A label for a piece of text."""
user_created = "user_created"
text_reply_to_post = "text_reply_to_post"
post_rating = "post_rating"
post_ranking = "post_ranking"
text_reply_to_message = "text_reply_to_message"
message_rating = "message_rating"
message_ranking = "message_ranking"
@payload_type
class JournalEvent(BaseModel):
type: str
person_id: Optional[UUID]
post_id: Optional[UUID]
workpackage_id: Optional[UUID]
user_id: Optional[UUID]
message_id: Optional[UUID]
task_id: Optional[UUID]
task_type: Optional[str]
@payload_type
class TextReplyEvent(JournalEvent):
type: Literal[JournalEventType.text_reply_to_post] = JournalEventType.text_reply_to_post
type: Literal[JournalEventType.text_reply_to_message] = JournalEventType.text_reply_to_message
length: int
role: str
@payload_type
class RatingEvent(JournalEvent):
type: Literal[JournalEventType.post_rating] = JournalEventType.post_rating
type: Literal[JournalEventType.message_rating] = JournalEventType.message_rating
rating: int
@payload_type
class RankingEvent(JournalEvent):
type: Literal[JournalEventType.post_ranking] = JournalEventType.post_ranking
type: Literal[JournalEventType.message_ranking] = JournalEventType.message_ranking
ranking: list[int]
class JournalWriter:
def __init__(self, db: Session, api_client: ApiClient, person: Person):
def __init__(self, db: Session, api_client: ApiClient, user: User):
self.db = db
self.api_client = api_client
self.person = person
self.person_id = self.person.id if self.person else None
self.user = user
self.user_id = self.user.id if self.user else None
def log_text_reply(self, work_package: WorkPackage, post_id: UUID, role: str, length: int) -> Journal:
def log_text_reply(self, task: Task, message_id: UUID, role: str, length: int) -> Journal:
return self.log(
task_type=work_package.payload_type,
event_type=JournalEventType.text_reply_to_post,
task_type=task.payload_type,
event_type=JournalEventType.text_reply_to_message,
payload=TextReplyEvent(role=role, length=length),
workpackage_id=work_package.id,
post_id=post_id,
task_id=task.id,
message_id=message_id,
)
def log_rating(self, work_package: WorkPackage, post_id: UUID, rating: int) -> Journal:
def log_rating(self, task: Task, message_id: UUID, rating: int) -> Journal:
return self.log(
task_type=work_package.payload_type,
event_type=JournalEventType.post_rating,
task_type=task.payload_type,
event_type=JournalEventType.message_rating,
payload=RatingEvent(rating=rating),
workpackage_id=work_package.id,
post_id=post_id,
task_id=task.id,
message_id=message_id,
)
def log_ranking(self, work_package: WorkPackage, post_id: UUID, ranking: list[int]) -> Journal:
def log_ranking(self, task: Task, message_id: UUID, ranking: list[int]) -> Journal:
return self.log(
task_type=work_package.payload_type,
event_type=JournalEventType.post_ranking,
task_type=task.payload_type,
event_type=JournalEventType.message_ranking,
payload=RankingEvent(ranking=ranking),
workpackage_id=work_package.id,
post_id=post_id,
task_id=task.id,
message_id=message_id,
)
def log(
@@ -87,8 +87,8 @@ class JournalWriter:
payload: JournalEvent,
task_type: str,
event_type: str = None,
workpackage_id: Optional[UUID] = None,
post_id: Optional[UUID] = None,
task_id: Optional[UUID] = None,
message_id: Optional[UUID] = None,
commit: bool = True,
) -> Journal:
if event_type is None:
@@ -97,22 +97,22 @@ class JournalWriter:
else:
event_type = type(payload).__name__
if payload.person_id is None:
payload.person_id = self.person_id
if payload.post_id is None:
payload.post_id = post_id
if payload.workpackage_id is None:
payload.workpackage_id = workpackage_id
if payload.user_id is None:
payload.user_id = self.user_id
if payload.message_id is None:
payload.message_id = message_id
if payload.task_id is None:
payload.task_id = task_id
if payload.task_type is None:
payload.task_type = task_type
entry = Journal(
person_id=self.person_id,
user_id=self.user_id,
api_client_id=self.api_client.id,
created_date=utcnow(),
event_type=event_type,
event_payload=PayloadContainer(payload=payload),
post_id=post_id,
message_id=message_id,
)
self.db.add(entry)
+10 -10
View File
@@ -1,20 +1,20 @@
# -*- coding: utf-8 -*-
from .api_client import ApiClient
from .journal import Journal, JournalIntegration
from .person import Person
from .person_stats import PersonStats
from .post import Post
from .post_reaction import PostReaction
from .message import Message
from .message_reaction import MessageReaction
from .task import Task
from .text_labels import TextLabels
from .work_package import WorkPackage
from .user import User
from .user_stats import UserStats
__all__ = [
"ApiClient",
"Person",
"PersonStats",
"Post",
"PostReaction",
"WorkPackage",
"User",
"UserStats",
"Message",
"MessageReaction",
"Task",
"TextLabels",
"Journal",
"JournalIntegration",
+8 -8
View File
@@ -32,8 +32,8 @@ class InitialPromptPayload(TaskPayload):
@payload_type
class UserReplyPayload(TaskPayload):
type: Literal["user_reply"] = "user_reply"
class PrompterReplyPayload(TaskPayload):
type: Literal["prompter_reply"] = "prompter_reply"
conversation: protocol_schema.Conversation
hint: str | None
@@ -45,7 +45,7 @@ class AssistantReplyPayload(TaskPayload):
@payload_type
class PostPayload(BaseModel):
class MessagePayload(BaseModel):
text: str
@@ -56,13 +56,13 @@ class ReactionPayload(BaseModel):
@payload_type
class RatingReactionPayload(ReactionPayload):
type: Literal["post_rating"] = "post_rating"
type: Literal["message_rating"] = "message_rating"
rating: str
@payload_type
class RankingReactionPayload(ReactionPayload):
type: Literal["post_ranking"] = "post_ranking"
type: Literal["message_ranking"] = "message_ranking"
ranking: list[int]
@@ -81,10 +81,10 @@ class RankInitialPromptsPayload(TaskPayload):
@payload_type
class RankUserRepliesPayload(RankConversationRepliesPayload):
"""A task to rank a set of user replies to a conversation."""
class RankPrompterRepliesPayload(RankConversationRepliesPayload):
"""A task to rank a set of prompter replies to a conversation."""
type: Literal["rank_user_replies"] = "rank_user_replies"
type: Literal["rank_prompter_replies"] = "rank_prompter_replies"
@payload_type
+2 -2
View File
@@ -33,8 +33,8 @@ class Journal(SQLModel, table=True):
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
)
person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True)
post_id: Optional[UUID] = Field(foreign_key="post.id", nullable=True)
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
message_id: Optional[UUID] = Field(foreign_key="message.id", nullable=True)
api_client_id: UUID = Field(foreign_key="api_client.id")
event_type: str = Field(nullable=False, max_length=200)
@@ -10,9 +10,9 @@ from sqlmodel import Field, Index, SQLModel
from .payload_column_type import PayloadContainer, payload_column_type
class Post(SQLModel, table=True):
__tablename__ = "post"
__table_args__ = (Index("ix_post_frontend_post_id", "api_client_id", "frontend_post_id", unique=True),)
class Message(SQLModel, table=True):
__tablename__ = "message"
__table_args__ = (Index("ix_message_frontend_message_id", "api_client_id", "frontend_message_id", unique=True),)
id: Optional[UUID] = Field(
sa_column=sa.Column(
@@ -20,12 +20,12 @@ class Post(SQLModel, table=True):
),
)
parent_id: UUID = Field(nullable=True)
thread_id: UUID = Field(nullable=False, index=True)
workpackage_id: UUID = Field(nullable=True, index=True)
person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True)
role: str = Field(nullable=False, max_length=128)
message_tree_id: UUID = Field(nullable=False, index=True)
task_id: UUID = Field(nullable=True, index=True)
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
role: str = Field(nullable=False, max_length=128) # valid: "prompter" | "assistant"
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
frontend_post_id: str = Field(max_length=200, nullable=False)
frontend_message_id: str = Field(max_length=200, nullable=False)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
)
@@ -10,14 +10,14 @@ from sqlmodel import Field, SQLModel
from .payload_column_type import PayloadContainer, payload_column_type
class PostReaction(SQLModel, table=True):
__tablename__ = "post_reaction"
class MessageReaction(SQLModel, table=True):
__tablename__ = "message_reaction"
work_package_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("work_package.id"), nullable=False, primary_key=True)
task_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("task.id"), nullable=False, primary_key=True)
)
person_id: UUID = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("person.id"), nullable=False, primary_key=True)
user_id: UUID = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False, primary_key=True)
)
created_date: Optional[datetime] = Field(
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
@@ -11,8 +11,8 @@ from sqlmodel import Field, SQLModel
from .payload_column_type import PayloadContainer, payload_column_type
class WorkPackage(SQLModel, table=True):
__tablename__ = "work_package"
class Task(SQLModel, table=True):
__tablename__ = "task"
id: Optional[UUID] = Field(
sa_column=sa.Column(
@@ -23,15 +23,15 @@ class WorkPackage(SQLModel, table=True):
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
)
expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True)
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
payload_type: str = Field(nullable=False, max_length=200)
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
ack: Optional[bool] = None
done: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
frontend_ref_post_id: Optional[str] = None
thread_id: Optional[UUID] = None
parent_post_id: Optional[UUID] = None
frontend_message_id: Optional[str] = None
message_tree_id: Optional[UUID] = None
parent_message_id: Optional[UUID] = None
collective: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
@property
+3 -1
View File
@@ -21,5 +21,7 @@ class TextLabels(SQLModel, table=True):
)
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
text: str = Field(nullable=False, max_length=2**16)
post_id: Optional[UUID] = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("post.id"), nullable=True))
message_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=True)
)
labels: dict[str, float] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False)
@@ -8,9 +8,9 @@ import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, Index, SQLModel
class Person(SQLModel, table=True):
__tablename__ = "person"
__table_args__ = (Index("ix_person_username", "api_client_id", "username", "auth_method", unique=True),)
class User(SQLModel, table=True):
__tablename__ = "user"
__table_args__ = (Index("ix_user_username", "api_client_id", "username", "auth_method", unique=True),)
id: Optional[UUID] = Field(
sa_column=sa.Column(
@@ -8,11 +8,11 @@ import sqlalchemy.dialects.postgresql as pg
from sqlmodel import Field, SQLModel
class PersonStats(SQLModel, table=True):
__tablename__ = "person_stats"
class UserStats(SQLModel, table=True):
__tablename__ = "user_stats"
person_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("person.id"), primary_key=True)
user_id: Optional[UUID] = Field(
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), primary_key=True)
)
leader_score: int = 0
modified_date: Optional[datetime] = Field(
@@ -20,9 +20,9 @@ class PersonStats(SQLModel, table=True):
)
reactions: int = 0 # reactions sent by user
posts: int = 0 # posts sent by user
messages: int = 0 # messages sent by user
upvotes: int = 0 # received upvotes (from other users)
downvotes: int = 0 # received downvotes (from other users)
work_reward: int = 0 # reward for workpackage completions
compare_wins: int = 0 # num times user's post won compare tasks
compare_losses: int = 0 # num times users's post lost compare tasks
task_reward: int = 0 # reward for task completions
compare_wins: int = 0 # num times user's message won compare tasks
compare_losses: int = 0 # num times user's message lost compare tasks
+255 -251
View File
@@ -7,7 +7,7 @@ import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.journal_writer import JournalWriter
from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage
from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session, func
@@ -17,247 +17,245 @@ class PromptRepository:
def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]):
self.db = db
self.api_client = api_client
self.person = self.lookup_person(user)
self.person_id = self.person.id if self.person else None
self.journal = JournalWriter(db, api_client, self.person)
self.user = self.lookup_user(user)
self.user_id = self.user.id if self.user else None
self.journal = JournalWriter(db, api_client, self.user)
def lookup_person(self, user: protocol_schema.User) -> Person:
if not user:
def lookup_user(self, client_user: protocol_schema.User) -> User:
if not client_user:
return None
person: Person = (
self.db.query(Person)
user: User = (
self.db.query(User)
.filter(
Person.api_client_id == self.api_client.id,
Person.username == user.id,
Person.auth_method == user.auth_method,
User.api_client_id == self.api_client.id,
User.username == client_user.id,
User.auth_method == client_user.auth_method,
)
.first()
)
if person is None:
if user is None:
# user is unknown, create new record
person = Person(
username=user.id,
display_name=user.display_name,
user = User(
username=client_user.id,
display_name=client_user.display_name,
api_client_id=self.api_client.id,
auth_method=user.auth_method,
auth_method=client_user.auth_method,
)
self.db.add(person)
self.db.add(user)
self.db.commit()
self.db.refresh(person)
elif user.display_name and user.display_name != person.display_name:
self.db.refresh(user)
elif client_user.display_name and client_user.display_name != user.display_name:
# we found the user but the display name changed
person.display_name = user.display_name
self.db.add(person)
user.display_name = client_user.display_name
self.db.add(user)
self.db.commit()
return person
return user
def validate_post_id(self, post_id: str) -> None:
if not isinstance(post_id, str):
raise OasstError(f"post_id must be string, not {type(post_id)}", OasstErrorCode.INVALID_POST_ID)
if not post_id:
raise OasstError("post_id must not be empty", OasstErrorCode.INVALID_POST_ID)
def validate_frontend_message_id(self, message_id: str) -> None:
if not isinstance(message_id, str):
raise OasstError(
f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID
)
if not message_id:
raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
def bind_frontend_post_id(self, task_id: UUID, post_id: str):
self.validate_post_id(post_id)
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
self.validate_frontend_message_id(frontend_message_id)
# find work package
work_pack: WorkPackage = (
self.db.query(WorkPackage)
.filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id)
.first()
)
if work_pack is None:
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
if work_pack.expired:
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
if work_pack.done or work_pack.ack is not None:
raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED)
# find task
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
if task is None:
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if task.done or task.ack is not None:
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
work_pack.frontend_ref_post_id = post_id
work_pack.ack = True
task.frontend_message_id = frontend_message_id
task.ack = True
# ToDo: check race-condition, transaction
self.db.add(work_pack)
self.db.add(task)
self.db.commit()
def acknowledge_task_failure(self, task_id):
# find work package
work_pack: WorkPackage = (
self.db.query(WorkPackage)
.filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id)
.first()
)
if work_pack is None:
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
if work_pack.expired:
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
if work_pack.done or work_pack.ack is not None:
raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED)
# find task
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
if task is None:
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if task.done or task.ack is not None:
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
work_pack.ack = False
task.ack = False
# ToDo: check race-condition, transaction
self.db.add(work_pack)
self.db.add(task)
self.db.commit()
def fetch_post_by_frontend_post_id(self, frontend_post_id: str, fail_if_missing: bool = True) -> Post:
self.validate_post_id(frontend_post_id)
post: Post = (
self.db.query(Post)
.filter(Post.api_client_id == self.api_client.id, Post.frontend_post_id == frontend_post_id)
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
self.validate_frontend_message_id(frontend_message_id)
message: Message = (
self.db.query(Message)
.filter(Message.api_client_id == self.api_client.id, Message.frontend_message_id == frontend_message_id)
.one_or_none()
)
if fail_if_missing and post is None:
raise OasstError(f"Post with post_id {frontend_post_id} not found.", OasstErrorCode.POST_NOT_FOUND)
return post
if fail_if_missing and message is None:
raise OasstError(
f"Message with frontend_message_id {frontend_message_id} not found.", OasstErrorCode.MESSAGE_NOT_FOUND
)
return message
def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage:
self.validate_post_id(post_id)
work_pack = (
self.db.query(WorkPackage)
.filter(WorkPackage.api_client_id == self.api_client.id, WorkPackage.frontend_ref_post_id == post_id)
def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
self.validate_frontend_message_id(message_id)
task = (
self.db.query(Task)
.filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id)
.one_or_none()
)
return work_pack
return task
def store_text_reply(self, text: str, post_id: str, user_post_id: str, role: str = None) -> Post:
self.validate_post_id(post_id)
self.validate_post_id(user_post_id)
def store_text_reply(
self, text: str, frontend_message_id: str, user_frontend_message_id: str, role: str = None
) -> Message:
self.validate_frontend_message_id(frontend_message_id)
self.validate_frontend_message_id(user_frontend_message_id)
wp = self.fetch_workpackage_by_postid(post_id)
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
if wp is None:
raise OasstError(f"WorkPackage for {post_id=} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
if wp.expired:
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
if not wp.ack:
raise OasstError("WorkPackage is not acknowledged.", OasstErrorCode.WORK_PACKAGE_NOT_ACK)
if wp.done:
raise OasstError("WorkPackage already done.", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE)
if task is None:
raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if not task.ack:
raise OasstError("Task is not acknowledged.", OasstErrorCode.TASK_NOT_ACK)
if task.done:
raise OasstError("Task already done.", OasstErrorCode.TASK_ALREADY_DONE)
# If there's no parent post assume user started new conversation
role = "user"
# If there's no parent message assume user started new conversation
role = "prompter"
depth = 0
if wp.parent_post_id:
parent_post = self.fetch_post(wp.parent_post_id)
parent_post.children_count += 1
self.db.add(parent_post)
if task.parent_message_id:
parent_message = self.fetch_message(task.parent_message_id)
parent_message.children_count += 1
self.db.add(parent_message)
depth = parent_post.depth + 1
if parent_post.role == "assistant":
role = "user"
depth = parent_message.depth + 1
if parent_message.role == "assistant":
role = "prompter"
else:
role = "assistant"
# create reply post
new_post_id = uuid4()
user_post = self.insert_post(
post_id=new_post_id,
frontend_post_id=user_post_id,
parent_id=wp.parent_post_id,
thread_id=wp.thread_id or new_post_id,
workpackage_id=wp.id,
# create reply message
new_message_id = uuid4()
user_message = self.insert_message(
message_id=new_message_id,
frontend_message_id=user_frontend_message_id,
parent_id=task.parent_message_id,
message_tree_id=task.message_tree_id or new_message_id,
task_id=task.id,
role=role,
payload=db_payload.PostPayload(text=text),
payload=db_payload.MessagePayload(text=text),
depth=depth,
)
if not wp.collective:
wp.done = True
self.db.add(wp)
if not task.collective:
task.done = True
self.db.add(task)
self.db.commit()
self.journal.log_text_reply(work_package=wp, post_id=new_post_id, role=role, length=len(text))
return user_post
self.journal.log_text_reply(task=task, message_id=new_message_id, role=role, length=len(text))
return user_message
def store_rating(self, rating: protocol_schema.PostRating) -> PostReaction:
post = self.fetch_post_by_frontend_post_id(rating.post_id, fail_if_missing=True)
def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction:
message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True)
work_package = self.fetch_workpackage_by_postid(rating.post_id)
work_payload: db_payload.RateSummaryPayload = work_package.payload.payload
if type(work_payload) != db_payload.RateSummaryPayload:
task = self.fetch_task_by_frontend_message_id(rating.message_id)
task_payload: db_payload.RateSummaryPayload = task.payload.payload
if type(task_payload) != db_payload.RateSummaryPayload:
raise OasstError(
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}",
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
f"Task payload type mismatch: {type(task_payload)=} != {db_payload.RateSummaryPayload}",
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
)
if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max:
if rating.rating < task_payload.scale.min or rating.rating > task_payload.scale.max:
raise OasstError(
f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}",
f"Invalid rating value: {rating.rating=} not in {task_payload.scale=}",
OasstErrorCode.RATING_OUT_OF_RANGE,
)
# store reaction to post
# store reaction to message
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
reaction = self.insert_reaction(post.id, reaction_payload)
if not work_package.collective:
work_package.done = True
self.db.add(work_package)
reaction = self.insert_reaction(message.id, reaction_payload)
if not task.collective:
task.done = True
self.db.add(task)
self.journal.log_rating(work_package, post_id=post.id, rating=rating.rating)
logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.")
self.journal.log_rating(task, message_id=message.id, rating=rating.rating)
logger.info(f"Ranking {rating.rating} stored for task {task.id}.")
return reaction
def store_ranking(self, ranking: protocol_schema.PostRanking) -> PostReaction:
# fetch work_package
work_package = self.fetch_workpackage_by_postid(ranking.post_id)
if not work_package.collective:
work_package.done = True
self.db.add(work_package)
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> MessageReaction:
# fetch task
task = self.fetch_task_by_frontend_message_id(ranking.message_id)
if not task.collective:
task.done = True
self.db.add(task)
work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
work_package.payload.payload
task_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
task.payload.payload
)
match type(work_payload):
match type(task_payload):
case db_payload.RankUserRepliesPayload | db_payload.RankAssistantRepliesPayload:
case db_payload.RankPrompterRepliesPayload | db_payload.RankAssistantRepliesPayload:
# validate ranking
num_replies = len(work_payload.replies)
num_replies = len(task_payload.replies)
if sorted(ranking.ranking) != list(range(num_replies)):
raise OasstError(
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=}).",
OasstErrorCode.INVALID_RANKING_VALUE,
)
# store reaction to post
# store reaction to message
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
reaction = self.insert_reaction(work_package.id, reaction_payload)
# TODO: resolve post_id
self.journal.log_ranking(work_package, post_id=None, ranking=ranking.ranking)
reaction = self.insert_reaction(task.id, reaction_payload)
# TODO: resolve message_id
self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
return reaction
case db_payload.RankInitialPromptsPayload:
# validate ranking
if sorted(ranking.ranking) != list(range(num_prompts := len(work_payload.prompts))):
if sorted(ranking.ranking) != list(range(num_prompts := len(task_payload.prompts))):
raise OasstError(
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=}).",
OasstErrorCode.INVALID_RANKING_VALUE,
)
# store reaction to post
# store reaction to message
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
reaction = self.insert_reaction(work_package.id, reaction_payload)
# TODO: resolve post_id
self.journal.log_ranking(work_package, post_id=None, ranking=ranking.ranking)
reaction = self.insert_reaction(task.id, reaction_payload)
# TODO: resolve message_id
self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
return reaction
case _:
raise OasstError(
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}",
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
f"task payload type mismatch: {type(task_payload)=} != {db_payload.RankConversationRepliesPayload}",
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
)
def store_task(
self,
task: protocol_schema.Task,
thread_id: UUID = None,
parent_post_id: UUID = None,
message_tree_id: UUID = None,
parent_message_id: UUID = None,
collective: bool = False,
) -> WorkPackage:
) -> Task:
payload: db_payload.TaskPayload
match type(task):
case protocol_schema.SummarizeStoryTask:
@@ -271,8 +269,8 @@ class PromptRepository:
case protocol_schema.InitialPromptTask:
payload = db_payload.InitialPromptPayload(hint=task.hint)
case protocol_schema.UserReplyTask:
payload = db_payload.UserReplyPayload(conversation=task.conversation, hint=task.hint)
case protocol_schema.PrompterReplyTask:
payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
case protocol_schema.AssistantReplyTask:
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
@@ -280,8 +278,8 @@ class PromptRepository:
case protocol_schema.RankInitialPromptsTask:
payload = db_payload.RankInitialPromptsPayload(tpye=task.type, prompts=task.prompts)
case protocol_schema.RankUserRepliesTask:
payload = db_payload.RankUserRepliesPayload(
case protocol_schema.RankPrompterRepliesTask:
payload = db_payload.RankPrompterRepliesPayload(
tpye=task.type, conversation=task.conversation, replies=task.replies
)
@@ -293,81 +291,85 @@ class PromptRepository:
case _:
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
wp = self.insert_work_package(
payload=payload, id=task.id, thread_id=thread_id, parent_post_id=parent_post_id, collective=collective
task = self.insert_task(
payload=payload,
id=task.id,
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
)
assert wp.id == task.id
return wp
assert task.id == task.id
return task
def insert_work_package(
def insert_task(
self,
payload: db_payload.TaskPayload,
id: UUID = None,
thread_id: UUID = None,
parent_post_id: UUID = None,
message_tree_id: UUID = None,
parent_message_id: UUID = None,
collective: bool = False,
) -> WorkPackage:
) -> Task:
c = PayloadContainer(payload=payload)
wp = WorkPackage(
task = Task(
id=id,
person_id=self.person_id,
user_id=self.user_id,
payload_type=type(payload).__name__,
payload=c,
api_client_id=self.api_client.id,
thread_id=thread_id,
parent_post_id=parent_post_id,
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
)
self.db.add(wp)
self.db.add(task)
self.db.commit()
self.db.refresh(wp)
return wp
self.db.refresh(task)
return task
def insert_post(
def insert_message(
self,
*,
post_id: UUID,
frontend_post_id: str,
message_id: UUID,
frontend_message_id: str,
parent_id: UUID,
thread_id: UUID,
workpackage_id: UUID,
message_tree_id: UUID,
task_id: UUID,
role: str,
payload: db_payload.PostPayload,
payload: db_payload.MessagePayload,
payload_type: str = None,
depth: int = 0,
) -> Post:
) -> Message:
if payload_type is None:
if payload is None:
payload_type = "null"
else:
payload_type = type(payload).__name__
post = Post(
id=post_id,
message = Message(
id=message_id,
parent_id=parent_id,
thread_id=thread_id,
workpackage_id=workpackage_id,
person_id=self.person_id,
message_tree_id=message_tree_id,
task_id=task_id,
user_id=self.user_id,
role=role,
frontend_post_id=frontend_post_id,
frontend_message_id=frontend_message_id,
api_client_id=self.api_client.id,
payload_type=payload_type,
payload=PayloadContainer(payload=payload),
depth=depth,
)
self.db.add(post)
self.db.add(message)
self.db.commit()
self.db.refresh(post)
return post
self.db.refresh(message)
return message
def insert_reaction(self, work_package_id: UUID, payload: db_payload.ReactionPayload) -> PostReaction:
if self.person_id is None:
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
if self.user_id is None:
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
container = PayloadContainer(payload=payload)
reaction = PostReaction(
work_package_id=work_package_id,
person_id=self.person_id,
reaction = MessageReaction(
task_id=task_id,
user_id=self.user_id,
payload=container,
api_client_id=self.api_client.id,
payload_type=type(payload).__name__,
@@ -383,108 +385,110 @@ class PromptRepository:
text=text_labels.text,
labels=text_labels.labels,
)
if text_labels.has_post_id:
self.fetch_post_by_frontend_post_id(text_labels.post_id, fail_if_missing=True)
model.post_id = text_labels.post_id
if text_labels.has_message_id:
self.fetch_message_by_frontend_message_id(text_labels.message_id, fail_if_missing=True)
model.message_id = text_labels.message_id
self.db.add(model)
self.db.commit()
self.db.refresh(model)
return model
def fetch_random_thread(self, require_role: str = None) -> list[Post]:
def fetch_random_message_tree(self, require_role: str = None) -> list[Message]:
"""
Loads all posts of a random thread.
Loads all messages of a random message_tree.
:param require_role: If set loads only thread which has
at least one post with given role.
:param require_role: If set loads only message_tree which has
at least one message with given role.
"""
distinct_threads = self.db.query(Post.thread_id).distinct(Post.thread_id)
distinct_message_trees = self.db.query(Message.message_tree_id).distinct(Message.message_tree_id)
if require_role:
distinct_threads = distinct_threads.filter(Post.role == require_role)
distinct_threads = distinct_threads.subquery()
distinct_message_trees = distinct_message_trees.filter(Message.role == require_role)
distinct_message_trees = distinct_message_trees.subquery()
random_thread = self.db.query(distinct_threads).order_by(func.random()).limit(1)
thread_posts = self.db.query(Post).filter(Post.thread_id.in_(random_thread)).all()
return thread_posts
random_message_tree = self.db.query(distinct_message_trees).order_by(func.random()).limit(1)
message_tree_messages = self.db.query(Message).filter(Message.message_tree_id.in_(random_message_tree)).all()
return message_tree_messages
def fetch_random_conversation(self, last_post_role: str = None) -> list[Post]:
def fetch_random_conversation(self, last_message_role: str = None) -> list[Message]:
"""
Picks a random linear conversation starting from any root post
and ending somewhere in the thread, possibly at the root itself.
Picks a random linear conversation starting from any root message
and ending somewhere in the message_tree, possibly at the root itself.
:param last_post_role: If set will form a conversation ending with a post
:param last_message_role: If set will form a conversation ending with a message
created by this role. Necessary for the tasks like "user_reply" where
the user should reply as a human and hence the last message of the conversation
needs to have "assistant" role.
"""
thread_posts = self.fetch_random_thread(last_post_role)
if not thread_posts:
raise OasstError("No threads found", OasstErrorCode.NO_THREADS_FOUND)
if last_post_role:
conv_posts = [p for p in thread_posts if p.role == last_post_role]
conv_posts = [random.choice(conv_posts)]
messages_tree = self.fetch_random_message_tree(last_message_role)
if not messages_tree:
raise OasstError("No message tree found", OasstErrorCode.NO_MESSAGE_TREE_FOUND)
if last_message_role:
conv_messages = [m for m in messages_tree if m.role == last_message_role]
conv_messages = [random.choice(conv_messages)]
else:
conv_posts = [random.choice(thread_posts)]
thread_posts = {p.id: p for p in thread_posts}
conv_messages = [random.choice(messages_tree)]
messages_tree = {m.id: m for m in messages_tree}
while True:
if not conv_posts[-1].parent_id:
if not conv_messages[-1].parent_id:
# reached the start of the conversation
break
parent_post = thread_posts[conv_posts[-1].parent_id]
conv_posts.append(parent_post)
parent_message = messages_tree[conv_messages[-1].parent_id]
conv_messages.append(parent_message)
return list(reversed(conv_posts))
return list(reversed(conv_messages))
def fetch_random_initial_prompts(self, size: int = 5):
posts = self.db.query(Post).filter(Post.parent_id.is_(None)).order_by(func.random()).limit(size).all()
return posts
messages = self.db.query(Message).filter(Message.parent_id.is_(None)).order_by(func.random()).limit(size).all()
return messages
def fetch_thread(self, thread_id: UUID):
return self.db.query(Post).filter(Post.thread_id == thread_id).all()
def fetch_message_tree(self, message_tree_id: UUID):
return self.db.query(Message).filter(Message.message_tree_id == message_tree_id).all()
def fetch_multiple_random_replies(self, max_size: int = 5, post_role: str = None):
parent = self.db.query(Post.id).filter(Post.children_count > 1)
if post_role:
parent = parent.filter(Post.role == post_role)
def fetch_multiple_random_replies(self, max_size: int = 5, message_role: str = None):
parent = self.db.query(Message.id).filter(Message.children_count > 1)
if message_role:
parent = parent.filter(Message.role == message_role)
parent = parent.order_by(func.random()).limit(1)
replies = self.db.query(Post).filter(Post.parent_id.in_(parent)).order_by(func.random()).limit(max_size).all()
replies = (
self.db.query(Message).filter(Message.parent_id.in_(parent)).order_by(func.random()).limit(max_size).all()
)
if not replies:
raise OasstError("No replies found", OasstErrorCode.NO_REPLIES_FOUND)
thread = self.fetch_thread(replies[0].thread_id)
thread = {p.id: p for p in thread}
thread_posts = [thread[replies[0].parent_id]]
message_tree = self.fetch_message_tree(replies[0].message_tree_id)
message_tree = {p.id: p for p in message_tree}
conversation = [message_tree[replies[0].parent_id]]
while True:
if not thread_posts[-1].parent_id:
if not conversation[-1].parent_id:
# reached start of the conversation
break
parent_post = thread[thread_posts[-1].parent_id]
thread_posts.append(parent_post)
parent_message = message_tree[conversation[-1].parent_id]
conversation.append(parent_message)
thread_posts = reversed(thread_posts)
conversation = reversed(conversation)
return thread_posts, replies
return conversation, replies
def fetch_post(self, post_id: UUID) -> Optional[Post]:
return self.db.query(Post).filter(Post.id == post_id).one()
def fetch_message(self, message_id: UUID) -> Optional[Message]:
return self.db.query(Message).filter(Message.id == message_id).one()
def close_task(self, post_id: str, allow_personal_tasks: bool = False):
self.validate_post_id(post_id)
wp = self.fetch_workpackage_by_postid(post_id)
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
self.validate_frontend_message_id(frontend_message_id)
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
if not wp:
raise OasstError("Work package not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
if wp.expired:
raise OasstError("Work package expired", OasstErrorCode.WORK_PACKAGE_EXPIRED)
if not allow_personal_tasks and not wp.collective:
raise OasstError("This is not a collective task", OasstErrorCode.WORK_PACKAGE_NOT_COLLECTIVE)
if wp.done:
raise OasstError("Allready closed", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE)
if not task:
raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
if not allow_personal_tasks and not task.collective:
raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
if task.done:
raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)
wp.done = True
self.db.add(wp)
task.done = True
self.db.add(task)
self.db.commit()
+9 -8
View File
@@ -10,16 +10,17 @@ from loguru import logger
from oasst_shared.schemas import protocol as protocol_schema
# TODO: Move to `protocol`?
class TaskType(str, enum.Enum):
"""Task types."""
summarize_story = "summarize_story"
rate_summary = "rate_summary"
initial_prompt = "initial_prompt"
user_reply = "user_reply"
prompter_reply = "prompter_reply"
assistant_reply = "assistant_reply"
rank_initial_prompts = "rank_initial_prompts"
rank_user_replies = "rank_user_replies"
rank_prompter_replies = "rank_prompter_replies"
rank_assistant_replies = "rank_assistant_replies"
done = "task_done"
@@ -44,10 +45,10 @@ class OasstApiClient:
TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
TaskType.rate_summary: protocol_schema.RateSummaryTask,
TaskType.initial_prompt: protocol_schema.InitialPromptTask,
TaskType.user_reply: protocol_schema.UserReplyTask,
TaskType.prompter_reply: protocol_schema.PrompterReplyTask,
TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask,
TaskType.rank_prompter_replies: protocol_schema.RankPrompterRepliesTask,
TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
TaskType.done: protocol_schema.TaskDone,
}
@@ -78,7 +79,7 @@ class OasstApiClient:
logger.debug(f"Fetching task {task_type} for user {user}")
req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective)
resp = await self.post("/api/v1/tasks/", data=req.dict())
logger.debug(f"Fetch task response: {resp}")
logger.debug(f"RESP {resp}")
return self._parse_task(resp)
async def fetch_random_task(
@@ -88,10 +89,10 @@ class OasstApiClient:
logger.debug(f"Fetching random for user {user}")
return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective)
async def ack_task(self, task_id: str | UUID, post_id: str):
async def ack_task(self, task_id: str | UUID, message_id: str):
"""Send an ACK for a task to the backend."""
logger.debug(f"ACK task {task_id} with post {post_id}")
req = protocol_schema.TaskAck(post_id=post_id)
logger.debug(f"ACK task {task_id} with post {message_id}")
req = protocol_schema.TaskAck(message_id=message_id)
return await self.post(f"/api/v1/tasks/{task_id}/ack", data=req.dict())
async def nack_task(self, task_id: str | UUID, reason: str):
+1 -1
View File
@@ -23,6 +23,6 @@ class GuildSettings(BaseModel):
await cursor.execute("SELECT * FROM guild_settings WHERE guild_id = ?", (guild_id,))
row = await cursor.fetchone()
if row is None:
raise ValueError("No settings found for this guild.")
return None
return cls.parse_obj(row)
+20 -19
View File
@@ -84,9 +84,9 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType)
logger.debug(f"Successful user input received: {event.content}")
# Send the response to the backend
reply = protocol_schema.TextReplyToPost(
post_id=str(msg_id),
user_post_id=str(event.message_id),
reply = protocol_schema.TextReplyToMessage(
message_id=str(msg_id),
user_message_id=str(event.message_id),
user=protocol_schema.User(
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
),
@@ -206,20 +206,20 @@ async def _send_task(
logger.debug("sending rank initial prompt task")
embed = _rank_initial_prompt_embed(task)
elif task.type == TaskRequestType.rank_user_replies:
assert isinstance(task, protocol_schema.RankUserRepliesTask)
elif task.type == TaskRequestType.rank_prompter_replies:
assert isinstance(task, protocol_schema.RankPrompterRepliesTask)
logger.debug("sending rank user reply task")
embed = _rank_user_reply_embed(task)
embed = _rank_prompter_reply_embed(task)
elif task.type == TaskRequestType.rank_assistant_replies:
assert isinstance(task, protocol_schema.RankAssistantRepliesTask)
logger.debug("sending rank assistant reply task")
embed = _rank_assistant_reply_embed(task)
elif task.type == TaskRequestType.user_reply:
assert isinstance(task, protocol_schema.UserReplyTask)
elif task.type == TaskRequestType.prompter_reply:
assert isinstance(task, protocol_schema.PrompterReplyTask)
logger.debug("sending user reply task")
embed = _user_reply_embed(task)
embed = _prompter_reply_embed(task)
elif task.type == TaskRequestType.assistant_reply:
assert isinstance(task, protocol_schema.AssistantReplyTask)
@@ -258,28 +258,29 @@ def _validate_user_input(content: str | None, task: protocol_schema.Task) -> boo
# User message input
if (
task.type == TaskRequestType.initial_prompt
or task.type == TaskRequestType.user_reply
or task.type == TaskRequestType.prompter_reply
or task.type == TaskRequestType.assistant_reply
):
assert isinstance(
task, protocol_schema.InitialPromptTask | protocol_schema.UserReplyTask | protocol_schema.AssistantReplyTask
task,
protocol_schema.InitialPromptTask | protocol_schema.PrompterReplyTask | protocol_schema.AssistantReplyTask,
)
return len(content) > 0
# Ranking tasks
elif task.type == TaskRequestType.rank_user_replies or task.type == TaskRequestType.rank_assistant_replies:
assert isinstance(task, protocol_schema.RankUserRepliesTask | protocol_schema.RankAssistantRepliesTask)
elif task.type == TaskRequestType.rank_prompter_replies or task.type == TaskRequestType.rank_assistant_replies:
assert isinstance(task, protocol_schema.RankPrompterRepliesTask | protocol_schema.RankAssistantRepliesTask)
num_replies = len(task.replies)
rankings = [int(r) for r in content.split(",")]
return all([r in range(1, num_replies + 1) for r in rankings]) and len(rankings) == num_replies
rankings = content.split(",")
return set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies
elif task.type == TaskRequestType.rank_initial_prompts:
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
num_prompts = len(task.prompts)
rankings = [int(r) for r in content.split(",")]
return all([r in range(1, num_prompts + 1) for r in rankings]) and len(rankings) == num_prompts
rankings = content.split(",")
return set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts
elif task.type == TaskRequestType.summarize_story:
raise NotImplementedError
@@ -369,7 +370,7 @@ def _rank_initial_prompt_embed(task: protocol_schema.RankInitialPromptsTask) ->
return embed
def _rank_user_reply_embed(task: protocol_schema.RankUserRepliesTask) -> hikari.Embed:
def _rank_prompter_reply_embed(task: protocol_schema.RankPrompterRepliesTask) -> hikari.Embed:
embed = (
hikari.Embed(
title="Rank User Reply",
@@ -403,7 +404,7 @@ def _rank_assistant_reply_embed(task: protocol_schema.RankAssistantRepliesTask)
return embed
def _user_reply_embed(task: protocol_schema.UserReplyTask) -> hikari.Embed:
def _prompter_reply_embed(task: protocol_schema.PrompterReplyTask) -> hikari.Embed:
embed = (
hikari.Embed(
title="User Reply",
+34 -34
View File
@@ -12,10 +12,10 @@ class TaskRequestType(str, enum.Enum):
summarize_story = "summarize_story"
rate_summary = "rate_summary"
initial_prompt = "initial_prompt"
user_reply = "user_reply"
prompter_reply = "prompter_reply"
assistant_reply = "assistant_reply"
rank_initial_prompts = "rank_initial_prompts"
rank_user_replies = "rank_user_replies"
rank_prompter_replies = "rank_prompter_replies"
rank_assistant_replies = "rank_assistant_replies"
@@ -33,7 +33,7 @@ class ConversationMessage(BaseModel):
class Conversation(BaseModel):
"""Represents a conversation between the user and the assistant."""
"""Represents a conversation between the prompter and the assistant."""
messages: list[ConversationMessage] = []
@@ -47,13 +47,13 @@ class TaskRequest(BaseModel):
class TaskAck(BaseModel):
"""The frontend acknowledges that it has received a task and created a post."""
"""The frontend acknowledges that it has received a task and created a message."""
post_id: str
message_id: str
class TaskNAck(BaseModel):
"""The frontend acknowledges that it has received a task but cannot create a post."""
"""The frontend acknowledges that it has received a task but cannot create a message."""
reason: str
@@ -61,7 +61,7 @@ class TaskNAck(BaseModel):
class TaskClose(BaseModel):
"""The frontend asks to mark task as done"""
post_id: str
message_id: str
class Task(BaseModel):
@@ -114,10 +114,10 @@ class ReplyToConversationTask(Task):
conversation: Conversation # the conversation so far
class UserReplyTask(ReplyToConversationTask, WithHintMixin):
class PrompterReplyTask(ReplyToConversationTask, WithHintMixin):
"""A task to prompt the user to submit a reply to the assistant."""
type: Literal["user_reply"] = "user_reply"
type: Literal["prompter_reply"] = "prompter_reply"
class AssistantReplyTask(ReplyToConversationTask):
@@ -141,10 +141,10 @@ class RankConversationRepliesTask(Task):
replies: list[str]
class RankUserRepliesTask(RankConversationRepliesTask):
"""A task to rank a set of user replies to a conversation."""
class RankPrompterRepliesTask(RankConversationRepliesTask):
"""A task to rank a set of prompter replies to a conversation."""
type: Literal["rank_user_replies"] = "rank_user_replies"
type: Literal["rank_prompter_replies"] = "rank_prompter_replies"
class RankAssistantRepliesTask(RankConversationRepliesTask):
@@ -165,11 +165,11 @@ AnyTask = Union[
RateSummaryTask,
InitialPromptTask,
ReplyToConversationTask,
UserReplyTask,
PrompterReplyTask,
AssistantReplyTask,
RankInitialPromptsTask,
RankConversationRepliesTask,
RankUserRepliesTask,
RankPrompterRepliesTask,
RankAssistantRepliesTask,
]
@@ -181,35 +181,35 @@ class Interaction(BaseModel):
user: User
class TextReplyToPost(Interaction):
"""A user has replied to a post with text."""
class TextReplyToMessage(Interaction):
"""A user has replied to a message with text."""
type: Literal["text_reply_to_post"] = "text_reply_to_post"
post_id: str
user_post_id: str
type: Literal["text_reply_to_message"] = "text_reply_to_message"
message_id: str
user_message_id: str
text: str
class PostRating(Interaction):
"""A user has rated a post."""
class MessageRating(Interaction):
"""A user has rated a message."""
type: Literal["post_rating"] = "post_rating"
post_id: str
type: Literal["message_rating"] = "message_rating"
message_id: str
rating: int
class PostRanking(Interaction):
"""A user has given a ranking for a post."""
class MessageRanking(Interaction):
"""A user has given a ranking for a message."""
type: Literal["post_ranking"] = "post_ranking"
post_id: str
type: Literal["message_ranking"] = "message_ranking"
message_id: str
ranking: list[int]
AnyInteraction = Union[
TextReplyToPost,
PostRating,
PostRanking,
TextReplyToMessage,
MessageRating,
MessageRanking,
]
@@ -245,12 +245,12 @@ class TextLabels(BaseModel):
text: str
labels: dict[TextLabel, float]
post_id: str | None = None
message_id: str | None = None
@property
def has_post_id(self) -> bool:
"""Whether this TextLabels has a post_id."""
return bool(self.post_id)
def has_message_id(self) -> bool:
"""Whether this TextLabels has a message_id."""
return bool(self.message_id)
# check that each label value is between 0 and 1
@pydantic.validator("labels")
+40 -40
View File
@@ -13,7 +13,7 @@ app = typer.Typer()
USER = {"id": "1234", "display_name": "John Doe", "auth_method": "local"}
def _random_post_id():
def _random_message_id():
return str(random.randint(1000, 9999))
@@ -21,7 +21,7 @@ def _render_message(message: dict) -> str:
"""Render a message to the user."""
if message["is_assistant"]:
return f"Assistant: {message['text']}"
return f"User: {message['text']}"
return f"Prompter: {message['text']}"
@app.command()
@@ -43,20 +43,20 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
typer.echo(task["story"])
# acknowledge task
post_id = _random_post_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"post_id": post_id})
message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
summary = typer.prompt("Enter your summary")
user_post_id = _random_post_id()
user_message_id = _random_message_id()
# send interaction
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_reply_to_post",
"post_id": post_id,
"user_post_id": user_post_id,
"type": "text_reply_to_message",
"message_id": message_id,
"user_message_id": user_message_id,
"text": summary,
"user": USER,
},
@@ -70,16 +70,16 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
typer.echo(f"Rating scale: {task['scale']['min']} - {task['scale']['max']}")
# acknowledge task
post_id = _random_post_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"post_id": post_id})
message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
rating = typer.prompt("Enter your rating", type=int)
# send interaction
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "post_rating",
"post_id": post_id,
"type": "message_rating",
"message_id": message_id,
"rating": rating,
"user": USER,
},
@@ -90,24 +90,24 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
if task["hint"]:
typer.echo(f"Hint: {task['hint']}")
# acknowledge task
post_id = _random_post_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"post_id": post_id})
message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
prompt = typer.prompt("Enter your prompt")
user_post_id = _random_post_id()
user_message_id = _random_message_id()
# send interaction
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_reply_to_post",
"post_id": post_id,
"user_post_id": user_post_id,
"type": "text_reply_to_message",
"message_id": message_id,
"user_message_id": user_message_id,
"text": prompt,
"user": USER,
},
)
tasks.append(new_task)
case "user_reply":
case "prompter_reply":
typer.echo("Please provide a reply to the assistant.")
typer.echo("Here is the conversation so far:")
for message in task["conversation"]["messages"]:
@@ -115,17 +115,17 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
if task["hint"]:
typer.echo(f"Hint: {task['hint']}")
# acknowledge task
post_id = _random_post_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"post_id": post_id})
message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
reply = typer.prompt("Enter your reply")
user_post_id = _random_post_id()
user_message_id = _random_message_id()
# send interaction
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_reply_to_post",
"post_id": post_id,
"user_post_id": user_post_id,
"type": "text_reply_to_message",
"message_id": message_id,
"user_message_id": user_message_id,
"text": reply,
"user": USER,
},
@@ -138,17 +138,17 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
for message in task["conversation"]["messages"]:
typer.echo(_render_message(message))
# acknowledge task
post_id = _random_post_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"post_id": post_id})
message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
reply = typer.prompt("Enter your reply")
user_post_id = _random_post_id()
user_message_id = _random_message_id()
# send interaction
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "text_reply_to_post",
"post_id": post_id,
"user_post_id": user_post_id,
"type": "text_reply_to_message",
"message_id": message_id,
"user_message_id": user_message_id,
"text": reply,
"user": USER,
},
@@ -160,8 +160,8 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
for idx, prompt in enumerate(task["prompts"], start=1):
typer.echo(f"{idx}: {prompt}")
# acknowledge task
post_id = _random_post_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"post_id": post_id})
message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
ranking_str = typer.prompt("Enter the prompt numbers in order of preference, separated by commas")
ranking = [int(x) - 1 for x in ranking_str.split(",")]
@@ -170,15 +170,15 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "post_ranking",
"post_id": post_id,
"type": "message_ranking",
"message_id": message_id,
"ranking": ranking,
"user": USER,
},
)
tasks.append(new_task)
case "rank_user_replies" | "rank_assistant_replies":
case "rank_prompter_replies" | "rank_assistant_replies":
typer.echo("Here is the conversation so far:")
for message in task["conversation"]["messages"]:
typer.echo(_render_message(message))
@@ -186,8 +186,8 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
for idx, reply in enumerate(task["replies"], start=1):
typer.echo(f"{idx}: {reply}")
# acknowledge task
post_id = _random_post_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"post_id": post_id})
message_id = _random_message_id()
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
ranking_str = typer.prompt("Enter the reply numbers in order of preference, separated by commas")
ranking = [int(x) - 1 for x in ranking_str.split(",")]
@@ -196,8 +196,8 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
new_task = _post(
"/api/v1/tasks/interaction",
{
"type": "post_ranking",
"post_id": post_id,
"type": "message_ranking",
"message_id": message_id,
"ranking": ranking,
"user": USER,
},
+4
View File
@@ -37,3 +37,7 @@ next-env.d.ts
# Vim files
*.swp
# cypress
/cypress-visual-screenshots/diff
/cypress-visual-screenshots/comparison
+15
View File
@@ -102,6 +102,21 @@ All static images, fonts, svgs, etc are stored in `public/`.
We're not really using CSS styles. `styles/` can be ignored.
## Testing the UI
Cypress is used for end-to-end (e2e) and component testing and is configured in `./cypress.config.ts`. The `./cypress` folder is used for supporting configuration files etc.
- Store e2e tests in the `./cypress/e2e` folder.
- Store component tests adjacent to the component being tested. If you want to write a test for `./src/components/Layout.tsx`, then store the test file at `./src/components/Layout.cy.tsx`.
A few npm scripts are available for convenience:
- `npm run cypress`: Useful for development, it opens Cypress and allows you to explore, run and debug tests. It assumes you have the NextJS site running at `localhost:3000`.
- `npm run cypress:run`: Runs all tests. Useful for a quick sanity check before sending a PR or to run in CI pipelines.
- `npm run cypress:image-baseline`: If you have tests failing because of visual changes that were expected, this command will update the baseline images stored in `./cypress-visual-screenshots/baseline` with those from the adjacent comparison folder. More can be found in the [docs of `uktrade/cypress-image-diff`](https://github.com/uktrade/cypress-image-diff/blob/main/docs/CLI.md#update-all-baseline-images-for-failing-tests).
Read more in the [./cypress README](cypress/).
## Best Practices
When writing code for the website, we have a few best practices:
+25
View File
@@ -0,0 +1,25 @@
import { defineConfig } from "cypress";
import getCompareSnapshotsPlugin from "cypress-image-diff-js/dist/plugin";
// Cypress configuration for both testing types used by the website.
// Each testing type registers the cypress-image-diff plugin through
// `setupNodeEvents` so specs can call `compareSnapshot`.
export default defineConfig({
  e2e: {
    // Relative paths passed to `cy.visit()` are appended to this URL.
    baseUrl: "http://localhost:3000",
    setupNodeEvents(on, config) {
      // implement node event listeners here
      getCompareSnapshotsPlugin(on, config);
    },
  },

  component: {
    devServer: {
      framework: "next",
      bundler: "webpack",
    },
    // Viewport size is a testing-type level option. It was previously nested
    // inside `devServer`, which only takes dev-server settings, so the values
    // were not applied as viewport configuration.
    viewportWidth: 500,
    viewportHeight: 500,
    setupNodeEvents(on, config) {
      // implement node event listeners here
      getCompareSnapshotsPlugin(on, config);
    },
  },
});
+62
View File
@@ -0,0 +1,62 @@
# Component and e2e testing with Cypress
[Cypress](https://www.cypress.io/) is used for both component and end-to-end testing. Below are a few examples in the context of this site. To learn more, the [Cypress documentation](https://docs.cypress.io/guides/getting-started/opening-the-app) has it all.
Don't be put off by their commercial offerings. The core is open source; the cloud offering is not necessary at all and can be replaced by CI tooling and [community efforts](https://sorry-cypress.dev/).
# Component testing
To write a new component test, you can either create a new `.tsx` file adjacent to the component you want to test, or use the guide presented to you when running `npm run cypress`, which lets you easily create a skeleton test for an existing component.
If you have a `Button.tsx` component, create a file next to it called `Button.cy.tsx` which could look like this:
```typescript
import React from "react";
import { Button } from "./Button";
describe("<Button />", () => {
it("renders", () => {
// see: https://on.cypress.io/mounting-react
cy.mount(<Button className="border-gray-800 m-5">Test button</Button>);
cy.get("button").compareSnapshot("button-element");
});
});
```
## What's happening here?
First we use `cy.mount` to mount our component under test. Notice how we specify `className` and inner text - this is where we arrange our component with fake data that we could assert on later.
In the example above, we also use `cy.get` to select the rendered `button` element. Cypress has multiple ways to [select elements](https://docs.cypress.io/guides/references/best-practices), `get` is just one of them (and often not recommended).
Finally, we use `compareSnapshot`, a command provided by a plugin, which takes a screenshot of the `button` element and compares it to a baseline located in the `./cypress-visual-screenshots/baseline/` folder. If too many pixels differ between the two, the test fails.
# End-to-end (e2e) testing
e2e tests are stored in the `./cypress/e2e` folder and should be named `{page}.cy.ts` and located in a relative folder structure that mirrors the page under test.
When running `npm run cypress` and selecting e2e testing, we assume you have the NextJS site running at `localhost:3000`.
An example test, as of this writing, could look as follows:
```typescript
describe("signin flow", () => {
it("redirects to a confirmation page on submit of valid email address", () => {
cy.visit("/auth/signin");
cy.get(".chakra-input").type(`test@example.com{enter}`);
cy.url().should("contain", "/auth/verify");
});
});
export {};
```
## What's happening here?
First we use [`cy.visit`](https://docs.cypress.io/api/commands/visit) to point the browser at the desired page. It appends relative paths to the configured `baseUrl` (found in `./cypress.config.ts`).
Cypress will [automatically await](https://docs.cypress.io/guides/core-concepts/introduction-to-cypress#Timeouts) almost anything you do, but fail if the default timeout is reached.
Then we get the email input field and type our email address. Notice the `{enter}` keyword, this will cause Cypress to hit the return key which we expect to submit the form.
We then assert that the URL should contain `/auth/verify`. Again the timeout will make sure we are not waiting forever, and the test will fail if we do not manage to get there in a reasonable time.
+10
View File
@@ -0,0 +1,10 @@
// End-to-end spec for the email sign-in page.
describe("signin flow", () => {
  it("redirects to a confirmation page on submit of valid email address", () => {
    // Selectors for the email field and the submit button of the sign-in form.
    const emailInput = ".chakra-input";
    const submitButton = ".chakra-stack > .chakra-button";

    // Open the sign-in page (path is relative to the configured base URL).
    cy.visit("/auth/signin");

    // Fill in an address and submit the form by clicking the button.
    cy.get(emailInput).type(`test@example.com`);
    cy.get(submitButton).click();

    // A successful submit should land on the verification page.
    cy.url().should("contain", "/auth/verify");
  });
});

// Empty export so this spec file is treated as a module.
export {};
+5
View File
@@ -0,0 +1,5 @@
{
"name": "Using fixtures to represent data",
"email": "hello@cypress.io",
"body": "Fixtures are a great way to mock data for responses to routes"
}
+39
View File
@@ -0,0 +1,39 @@
/// <reference types="cypress" />
// ***********************************************
// This example commands.ts shows you how to
// create various custom commands and overwrite
// existing commands.
//
// For more comprehensive examples of custom
// commands please read more here:
// https://on.cypress.io/custom-commands
// ***********************************************
//
//
// -- This is a parent command --
// Cypress.Commands.add('login', (email, password) => { ... })
//
//
// -- This is a child command --
// Cypress.Commands.add('drag', { prevSubject: 'element'}, (subject, options) => { ... })
//
//
// -- This is a dual command --
// Cypress.Commands.add('dismiss', { prevSubject: 'optional'}, (subject, options) => { ... })
//
//
// -- This will overwrite an existing command --
// Cypress.Commands.overwrite('visit', (originalFn, url, options) => { ... })
//
// declare global {
// namespace Cypress {
// interface Chainable {
// login(email: string, password: string): Chainable<void>
// drag(subject: string, options?: Partial<TypeOptions>): Chainable<Element>
// dismiss(subject: string, options?: Partial<TypeOptions>): Chainable<Element>
// visit(originalFn: CommandOriginalFn, url: string, options: Partial<VisitOptions>): Chainable<Element>
// }
// }
// }
export {};
@@ -0,0 +1,14 @@
<!DOCTYPE html>
<!-- HTML shell for component tests; the page title below identifies it as the "Components App". -->
<html>
<head>
<meta charset="utf-8" />
<meta http-equiv="X-UA-Compatible" content="IE=edge" />
<meta name="viewport" content="width=device-width,initial-scale=1.0" />
<title>Components App</title>
<!-- Used by Next.js to inject CSS. -->
<div id="__next_css__DO_NOT_USE__"></div>
</head>
<body>
<!-- NOTE(review): presumably the root element that `cy.mount` renders the component under test into; confirm against the Cypress docs. -->
<div data-cy-root></div>
</body>
</html>
+45
View File
@@ -0,0 +1,45 @@
// ***********************************************************
// This example support/component.ts is processed and
// loaded automatically before your test files.
//
// This is a great place to put global configuration and
// behavior that modifies Cypress.
//
// You can change the location of this file or turn off
// automatically serving support files with the
// 'supportFile' configuration option.
//
// You can read more here:
// https://on.cypress.io/configuration
// ***********************************************************
// Import commands.js using ES2015 syntax:
import "./commands";
import compareSnapshotCommand from "cypress-image-diff-js/dist/command";
import "../../src/styles/globals.css";
// Alternatively you can use CommonJS syntax:
// require('./commands')
import { mount } from "cypress/react18";
// Augment the Cypress namespace to include type definitions for
// your custom command.
// Alternatively, can be defined in cypress/support/component.d.ts
// with a <reference path="./component" /> at the top of your spec.
// Type-level only: merge a `mount` member into the global Cypress `Chainable`
// interface so that `cy.mount(...)` type-checks in spec files.
declare global {
namespace Cypress {
interface Chainable {
// Same type as the `mount` function imported from "cypress/react18".
mount: typeof mount;
}
}
}
// Runtime counterpart of the declaration above: register the imported `mount`
// as a custom command under the name "mount".
Cypress.Commands.add("mount", mount);
// Example use:
// cy.mount(<MyComponent />)
// Invoke the default export of "cypress-image-diff-js/dist/command" once when
// this support file is loaded.
// NOTE(review): presumably this registers the plugin's snapshot-comparison
// command; confirm against the cypress-image-diff-js documentation.
compareSnapshotCommand();
export {};
+24
View File
@@ -0,0 +1,24 @@
// ***********************************************************
// This example support/e2e.ts is processed and
// loaded automatically before your test files.
//
// This is a great place to put global configuration and
// behavior that modifies Cypress.
//
// You can change the location of this file or turn off
// automatically serving support files with the
// 'supportFile' configuration option.
//
// You can read more here:
// https://on.cypress.io/configuration
// ***********************************************************
// Import commands.js using ES2015 syntax:
import "./commands";
import compareSnapshotCommand from "cypress-image-diff-js/dist/command";
// Invoke the default export of "cypress-image-diff-js/dist/command" once when
// this support file is loaded.
// NOTE(review): presumably this registers the plugin's snapshot-comparison
// command; confirm against the cypress-image-diff-js documentation.
compareSnapshotCommand();
// Alternatively you can use CommonJS syntax:
// require('./commands')
export {};
+1490 -15487
View File
File diff suppressed because it is too large Load Diff
+6 -1
View File
@@ -9,7 +9,10 @@
"start": "next start",
"lint": "next lint",
"storybook": "start-storybook -p 6006",
"build-storybook": "build-storybook"
"build-storybook": "build-storybook",
"cypress": "cypress open",
"cypress:run": "cypress run",
"cypress:image-baseline": "cypress-image-diff -u"
},
"dependencies": {
"@chakra-ui/react": "^2.4.4",
@@ -56,6 +59,8 @@
"@types/node": "18.11.17",
"@types/react": "18.0.26",
"babel-loader": "^8.3.0",
"cypress": "^12.2.0",
"cypress-image-diff-js": "^1.23.0",
"eslint-plugin-storybook": "^0.6.8",
"@typescript-eslint/eslint-plugin": "^5.47.1",
"prettier": "2.8.1",
+12
View File
@@ -0,0 +1,12 @@
import React from "react";
import { Container } from "./Container";
// Component test for <Container />: mounts it with a class name and a text
// child, then checks the rendered <div> for both.
describe("<Container />", () => {
  it("renders", () => {
    // see: https://on.cypress.io/mounting-react
    const cssClass = "my-class";
    const label = "test_container";

    cy.mount(<Container className={cssClass}>{label}</Container>);

    // `.and()` is an alias of `.should()`; the chain asserts class, visibility
    // and text content on the same subject, in that order.
    const rendered = cy.get(`div.${cssClass}`);
    rendered
      .should("have.class", cssClass)
      .and("be.visible")
      .and("contain", label);
  });
});