mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
fix breaking api changes
This commit is contained in:
@@ -35,6 +35,8 @@ We are not going to stop at replicating ChatGPT. We want to build the assistant
|
||||
|
||||
### Slide Decks
|
||||
|
||||
[Vision & Roadmap](https://docs.google.com/presentation/d/1n7IrAOVOqwdYgiYrXc8Sj0He8krn5MVZO_iLkCjTtu0/edit?usp=sharing)
|
||||
|
||||
[Important Data Structures](https://docs.google.com/presentation/d/1iaX_nxasVWlvPiSNs0cllR9L_1neZq0RJxd6MFEalUY/edit?usp=sharing)
|
||||
|
||||
## How can you help?
|
||||
@@ -43,9 +45,11 @@ All open source projects begins with people like you. Open source is the belief
|
||||
|
||||
## I’m in! Now what?
|
||||
|
||||
[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e)
|
||||
[Join the OpenAssistant Contributors Discord Server!](https://ykilcher.com/open-assistant-discord), this is for work coordination.
|
||||
|
||||
[and / or the YK Discord Server](https://ykilcher.com/discord)
|
||||
[Join the LAION Discord Server!](https://discord.com/invite/mVcgxMPD7e), it has a dedicated channel and is more public.
|
||||
|
||||
[and / or the YK Discord Server](https://ykilcher.com/discord), also has a dedicated, but not as active, channel.
|
||||
|
||||
[Visit the Notion](https://ykilcher.com/open-assistant)
|
||||
|
||||
@@ -61,6 +65,8 @@ has assigned the issue to you, start working on it.
|
||||
If the issue is currently unclear but you are interested, please post in
|
||||
Discord and someone can help clarify the issue with more detail.
|
||||
|
||||
**Always Welcome:** Documentation markdowns in `docs/`, docstrings, diagrams of the system architecture, and other documentation.
|
||||
|
||||
### Submitting Work
|
||||
|
||||
We're all working on different parts of Open Assistant together. To make
|
||||
|
||||
+339
@@ -0,0 +1,339 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""name changes: person->user, post->message, work_package->task
|
||||
|
||||
Revision ID: abb47e9d145a
|
||||
Revises: 73ce3675c1f5
|
||||
Create Date: 2022-12-30 20:54:49.880568
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "abb47e9d145a"
|
||||
down_revision = "73ce3675c1f5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# clear DB
|
||||
op.execute("DELETE FROM journal;")
|
||||
op.execute("DELETE FROM work_package;")
|
||||
op.execute("DELETE FROM post_reaction;")
|
||||
op.execute("DELETE FROM post;")
|
||||
op.execute("DELETE FROM person_stats;")
|
||||
op.execute("DELETE FROM person;")
|
||||
op.execute("DELETE FROM text_labels;")
|
||||
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"user",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
|
||||
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("username", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
|
||||
sa.Column("auth_method", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
|
||||
sa.Column("display_name", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False),
|
||||
sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["api_client_id"],
|
||||
["api_client.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("ix_user_username", "user", ["api_client_id", "username", "auth_method"], unique=True)
|
||||
op.create_table(
|
||||
"message",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
|
||||
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column("depth", sa.Integer(), server_default=sa.text("0"), nullable=False),
|
||||
sa.Column("children_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
|
||||
sa.Column("parent_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
|
||||
sa.Column("message_tree_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column("task_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
|
||||
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
|
||||
sa.Column("role", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
|
||||
sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column("frontend_message_id", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
|
||||
sa.Column("payload_type", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
|
||||
sa.Column("lang", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["api_client_id"],
|
||||
["api_client.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("ix_message_frontend_message_id", "message", ["api_client_id", "frontend_message_id"], unique=True)
|
||||
op.create_index(op.f("ix_message_message_tree_id"), "message", ["message_tree_id"], unique=False)
|
||||
op.create_index(op.f("ix_message_task_id"), "message", ["task_id"], unique=False)
|
||||
op.create_index(op.f("ix_message_user_id"), "message", ["user_id"], unique=False)
|
||||
op.create_table(
|
||||
"task",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
|
||||
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("expiry_date", sa.DateTime(), nullable=True),
|
||||
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column("done", sa.Boolean(), server_default=sa.text("false"), nullable=False),
|
||||
sa.Column("collective", sa.Boolean(), server_default=sa.text("false"), nullable=False),
|
||||
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
|
||||
sa.Column("payload_type", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
|
||||
sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column("ack", sa.Boolean(), nullable=True),
|
||||
sa.Column("frontend_message_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column("message_tree_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
|
||||
sa.Column("parent_message_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["api_client_id"],
|
||||
["api_client.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("ix_task_user_id"), "task", ["user_id"], unique=False)
|
||||
op.create_table(
|
||||
"user_stats",
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("modified_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("leader_score", sa.Integer(), nullable=False),
|
||||
sa.Column("reactions", sa.Integer(), nullable=False),
|
||||
sa.Column("messages", sa.Integer(), nullable=False),
|
||||
sa.Column("upvotes", sa.Integer(), nullable=False),
|
||||
sa.Column("downvotes", sa.Integer(), nullable=False),
|
||||
sa.Column("task_reward", sa.Integer(), nullable=False),
|
||||
sa.Column("compare_wins", sa.Integer(), nullable=False),
|
||||
sa.Column("compare_losses", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("user_id"),
|
||||
)
|
||||
op.create_table(
|
||||
"message_reaction",
|
||||
sa.Column("task_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("created_date", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column("payload_type", sqlmodel.sql.sqltypes.AutoString(length=200), nullable=False),
|
||||
sa.Column("api_client_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["api_client_id"],
|
||||
["api_client.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["task_id"],
|
||||
["task.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("task_id", "user_id"),
|
||||
)
|
||||
|
||||
op.drop_constraint("text_labels_post_id_fkey", "text_labels", type_="foreignkey")
|
||||
op.drop_constraint("journal_post_id_fkey", "journal", type_="foreignkey")
|
||||
op.drop_constraint("journal_person_id_fkey", "journal", type_="foreignkey")
|
||||
|
||||
op.drop_table("post_reaction")
|
||||
|
||||
op.drop_index("ix_post_frontend_post_id", table_name="post")
|
||||
op.drop_index("ix_post_person_id", table_name="post")
|
||||
op.drop_index("ix_post_thread_id", table_name="post")
|
||||
op.drop_index("ix_post_workpackage_id", table_name="post")
|
||||
op.drop_table("post")
|
||||
|
||||
op.drop_index("ix_work_package_person_id", table_name="work_package")
|
||||
op.drop_table("work_package")
|
||||
op.drop_table("person_stats")
|
||||
|
||||
op.drop_index("ix_person_username", table_name="person")
|
||||
op.drop_table("person")
|
||||
|
||||
op.add_column("journal", sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True))
|
||||
op.add_column("journal", sa.Column("message_id", sqlmodel.sql.sqltypes.GUID(), nullable=True))
|
||||
op.drop_index("ix_journal_person_id", table_name="journal")
|
||||
op.create_index(op.f("ix_journal_user_id"), "journal", ["user_id"], unique=False)
|
||||
|
||||
op.create_foreign_key(None, "journal", "user", ["user_id"], ["id"])
|
||||
op.create_foreign_key(None, "journal", "message", ["message_id"], ["id"])
|
||||
op.drop_column("journal", "person_id")
|
||||
op.drop_column("journal", "post_id")
|
||||
op.add_column("text_labels", sa.Column("message_id", postgresql.UUID(as_uuid=True), nullable=True))
|
||||
op.create_foreign_key(None, "text_labels", "message", ["message_id"], ["id"])
|
||||
op.drop_column("text_labels", "post_id")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# clear DB
|
||||
op.execute("DELETE FROM journal;")
|
||||
op.execute("DELETE FROM message_reaction;")
|
||||
op.execute("DELETE FROM task;")
|
||||
op.execute("DELETE FROM message;")
|
||||
op.execute("DELETE FROM user_stats;")
|
||||
op.execute('DELETE FROM "user";')
|
||||
op.execute("DELETE FROM text_labels;")
|
||||
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("text_labels", sa.Column("post_id", postgresql.UUID(), autoincrement=False, nullable=True))
|
||||
op.drop_constraint("text_labels_message_id_fkey", "text_labels", type_="foreignkey")
|
||||
|
||||
op.drop_column("text_labels", "message_id")
|
||||
op.add_column("journal", sa.Column("post_id", postgresql.UUID(), autoincrement=False, nullable=True))
|
||||
op.add_column("journal", sa.Column("person_id", postgresql.UUID(), autoincrement=False, nullable=True))
|
||||
op.drop_constraint("journal_message_id_fkey", "journal", type_="foreignkey")
|
||||
op.drop_constraint("journal_user_id_fkey", "journal", type_="foreignkey")
|
||||
|
||||
op.drop_index(op.f("ix_journal_user_id"), table_name="journal")
|
||||
op.create_index("ix_journal_person_id", "journal", ["person_id"], unique=False)
|
||||
op.drop_column("journal", "message_id")
|
||||
op.drop_column("journal", "user_id")
|
||||
|
||||
op.create_table(
|
||||
"person",
|
||||
sa.Column(
|
||||
"id", postgresql.UUID(), server_default=sa.text("gen_random_uuid()"), autoincrement=False, nullable=False
|
||||
),
|
||||
sa.Column("username", sa.VARCHAR(length=128), autoincrement=False, nullable=False),
|
||||
sa.Column("display_name", sa.VARCHAR(length=256), autoincrement=False, nullable=False),
|
||||
sa.Column(
|
||||
"created_date",
|
||||
postgresql.TIMESTAMP(),
|
||||
server_default=sa.text("CURRENT_TIMESTAMP"),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("api_client_id", postgresql.UUID(), autoincrement=False, nullable=False),
|
||||
sa.Column("auth_method", sa.VARCHAR(length=128), autoincrement=False, nullable=False),
|
||||
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"], name="person_api_client_id_fkey"),
|
||||
sa.PrimaryKeyConstraint("id", name="person_pkey"),
|
||||
)
|
||||
op.create_table(
|
||||
"person_stats",
|
||||
sa.Column("person_id", postgresql.UUID(), autoincrement=False, nullable=False),
|
||||
sa.Column("leader_score", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
sa.Column(
|
||||
"modified_date",
|
||||
postgresql.TIMESTAMP(),
|
||||
server_default=sa.text("CURRENT_TIMESTAMP"),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("reactions", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
sa.Column("posts", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
sa.Column("upvotes", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
sa.Column("downvotes", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
sa.Column("work_reward", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
sa.Column("compare_wins", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
sa.Column("compare_losses", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
sa.ForeignKeyConstraint(["person_id"], ["person.id"], name="person_stats_person_id_fkey"),
|
||||
sa.PrimaryKeyConstraint("person_id", name="person_stats_pkey"),
|
||||
)
|
||||
op.create_table(
|
||||
"work_package",
|
||||
sa.Column(
|
||||
"id", postgresql.UUID(), server_default=sa.text("gen_random_uuid()"), autoincrement=False, nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
"created_date",
|
||||
postgresql.TIMESTAMP(),
|
||||
server_default=sa.text("CURRENT_TIMESTAMP"),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("expiry_date", postgresql.TIMESTAMP(), autoincrement=False, nullable=True),
|
||||
sa.Column("person_id", postgresql.UUID(), autoincrement=False, nullable=True),
|
||||
sa.Column("payload_type", sa.VARCHAR(length=200), autoincrement=False, nullable=False),
|
||||
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=False),
|
||||
sa.Column("api_client_id", postgresql.UUID(), autoincrement=False, nullable=False),
|
||||
sa.Column("done", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False),
|
||||
sa.Column("ack", sa.BOOLEAN(), autoincrement=False, nullable=True),
|
||||
sa.Column("frontend_ref_post_id", sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column("thread_id", postgresql.UUID(), autoincrement=False, nullable=True),
|
||||
sa.Column("parent_post_id", postgresql.UUID(), autoincrement=False, nullable=True),
|
||||
sa.Column("collective", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False),
|
||||
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"], name="work_package_api_client_id_fkey"),
|
||||
sa.ForeignKeyConstraint(["person_id"], ["person.id"], name="work_package_person_id_fkey"),
|
||||
sa.PrimaryKeyConstraint("id", name="work_package_pkey"),
|
||||
)
|
||||
op.create_index("ix_work_package_person_id", "work_package", ["person_id"], unique=False)
|
||||
op.create_table(
|
||||
"post",
|
||||
sa.Column(
|
||||
"id", postgresql.UUID(), server_default=sa.text("gen_random_uuid()"), autoincrement=False, nullable=False
|
||||
),
|
||||
sa.Column("parent_id", postgresql.UUID(), autoincrement=False, nullable=True),
|
||||
sa.Column("thread_id", postgresql.UUID(), autoincrement=False, nullable=False),
|
||||
sa.Column("workpackage_id", postgresql.UUID(), autoincrement=False, nullable=True),
|
||||
sa.Column("person_id", postgresql.UUID(), autoincrement=False, nullable=True),
|
||||
sa.Column("api_client_id", postgresql.UUID(), autoincrement=False, nullable=False),
|
||||
sa.Column("role", sa.VARCHAR(length=128), autoincrement=False, nullable=False),
|
||||
sa.Column("frontend_post_id", sa.VARCHAR(length=200), autoincrement=False, nullable=False),
|
||||
sa.Column(
|
||||
"created_date",
|
||||
postgresql.TIMESTAMP(),
|
||||
server_default=sa.text("CURRENT_TIMESTAMP"),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("payload_type", sa.VARCHAR(length=200), autoincrement=False, nullable=False),
|
||||
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=True),
|
||||
sa.Column("depth", sa.INTEGER(), server_default=sa.text("0"), autoincrement=False, nullable=False),
|
||||
sa.Column("children_count", sa.INTEGER(), server_default=sa.text("0"), autoincrement=False, nullable=False),
|
||||
sa.Column("lang", sa.VARCHAR(length=200), autoincrement=False, nullable=False),
|
||||
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"], name="post_api_client_id_fkey"),
|
||||
sa.ForeignKeyConstraint(["person_id"], ["person.id"], name="post_person_id_fkey"),
|
||||
sa.PrimaryKeyConstraint("id", name="post_pkey"),
|
||||
)
|
||||
op.create_index("ix_post_workpackage_id", "post", ["workpackage_id"], unique=False)
|
||||
op.create_index("ix_post_thread_id", "post", ["thread_id"], unique=False)
|
||||
op.create_index("ix_post_person_id", "post", ["person_id"], unique=False)
|
||||
op.create_index("ix_post_frontend_post_id", "post", ["api_client_id", "frontend_post_id"], unique=False)
|
||||
|
||||
op.create_table(
|
||||
"post_reaction",
|
||||
sa.Column("person_id", postgresql.UUID(), autoincrement=False, nullable=False),
|
||||
sa.Column(
|
||||
"created_date",
|
||||
postgresql.TIMESTAMP(),
|
||||
server_default=sa.text("CURRENT_TIMESTAMP"),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("payload_type", sa.VARCHAR(length=200), autoincrement=False, nullable=False),
|
||||
sa.Column("payload", postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=False),
|
||||
sa.Column("api_client_id", postgresql.UUID(), autoincrement=False, nullable=False),
|
||||
sa.Column("work_package_id", postgresql.UUID(), autoincrement=False, nullable=False),
|
||||
sa.ForeignKeyConstraint(["api_client_id"], ["api_client.id"], name="post_reaction_api_client_id_fkey"),
|
||||
sa.ForeignKeyConstraint(["person_id"], ["person.id"], name="post_reaction_person_id_fkey"),
|
||||
sa.ForeignKeyConstraint(["work_package_id"], ["work_package.id"], name="post_reaction_work_package_id_fkey"),
|
||||
)
|
||||
|
||||
op.create_index("ix_person_username", "person", ["api_client_id", "username", "auth_method"], unique=False)
|
||||
op.create_foreign_key("text_labels_post_id_fkey", "text_labels", "post", ["post_id"], ["id"])
|
||||
op.create_foreign_key("journal_person_id_fkey", "journal", "person", ["person_id"], ["id"])
|
||||
op.create_foreign_key("journal_post_id_fkey", "journal", "post", ["post_id"], ["id"])
|
||||
|
||||
op.drop_table("message_reaction")
|
||||
op.drop_table("user_stats")
|
||||
op.drop_index(op.f("ix_task_user_id"), table_name="task")
|
||||
op.drop_table("task")
|
||||
op.drop_index(op.f("ix_message_user_id"), table_name="message")
|
||||
op.drop_index(op.f("ix_message_task_id"), table_name="message")
|
||||
op.drop_index(op.f("ix_message_message_tree_id"), table_name="message")
|
||||
op.drop_index("ix_message_frontend_message_id", table_name="message")
|
||||
op.drop_table("message")
|
||||
op.drop_index("ix_user_username", table_name="user")
|
||||
op.drop_table("user")
|
||||
# ### end Alembic commands ###
|
||||
+61
-60
@@ -67,10 +67,10 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
|
||||
@app.on_event("startup")
|
||||
def seed_data():
|
||||
class DummyPost(pydantic.BaseModel):
|
||||
task_post_id: str
|
||||
user_post_id: str
|
||||
parent_post_id: Optional[str]
|
||||
class DummyMessage(pydantic.BaseModel):
|
||||
task_message_id: str
|
||||
user_message_id: str
|
||||
parent_message_id: Optional[str]
|
||||
text: str
|
||||
role: str
|
||||
|
||||
@@ -81,96 +81,97 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
pr = PromptRepository(db=db, api_client=api_client, user=dummy_user)
|
||||
|
||||
dummy_posts = [
|
||||
DummyPost(
|
||||
task_post_id="de111fa8",
|
||||
user_post_id="6f1d0711",
|
||||
parent_post_id=None,
|
||||
dummy_messages = [
|
||||
DummyMessage(
|
||||
task_message_id="de111fa8",
|
||||
user_message_id="6f1d0711",
|
||||
parent_message_id=None,
|
||||
text="Hi!",
|
||||
role="user",
|
||||
role="prompter",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="74c381d4",
|
||||
user_post_id="4a24530b",
|
||||
parent_post_id="6f1d0711",
|
||||
DummyMessage(
|
||||
task_message_id="74c381d4",
|
||||
user_message_id="4a24530b",
|
||||
parent_message_id="6f1d0711",
|
||||
text="Hello! How can I help you?",
|
||||
role="assistant",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="3d5dc440",
|
||||
user_post_id="a8c01c04",
|
||||
parent_post_id="4a24530b",
|
||||
DummyMessage(
|
||||
task_message_id="3d5dc440",
|
||||
user_message_id="a8c01c04",
|
||||
parent_message_id="4a24530b",
|
||||
text="Do you have a recipe for potato soup?",
|
||||
role="user",
|
||||
role="prompter",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="643716c1",
|
||||
user_post_id="f43a93b7",
|
||||
parent_post_id="4a24530b",
|
||||
DummyMessage(
|
||||
task_message_id="643716c1",
|
||||
user_message_id="f43a93b7",
|
||||
parent_message_id="4a24530b",
|
||||
text="Who were the 8 presidents before George Washington?",
|
||||
role="user",
|
||||
role="prompter",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="2e4e1e6",
|
||||
user_post_id="c886920",
|
||||
parent_post_id="6f1d0711",
|
||||
DummyMessage(
|
||||
task_message_id="2e4e1e6",
|
||||
user_message_id="c886920",
|
||||
parent_message_id="6f1d0711",
|
||||
text="Hey buddy! How can I serve you?",
|
||||
role="assistant",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="970c437d",
|
||||
user_post_id="cec432cf",
|
||||
parent_post_id=None,
|
||||
DummyMessage(
|
||||
task_message_id="970c437d",
|
||||
user_message_id="cec432cf",
|
||||
parent_message_id=None,
|
||||
text="euirdteunvglfe23908230892309832098 AAAAAAAA",
|
||||
role="user",
|
||||
role="prompter",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="6066118e",
|
||||
user_post_id="4f85f637",
|
||||
parent_post_id="cec432cf",
|
||||
DummyMessage(
|
||||
task_message_id="6066118e",
|
||||
user_message_id="4f85f637",
|
||||
parent_message_id="cec432cf",
|
||||
text="Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?",
|
||||
role="assistant",
|
||||
),
|
||||
DummyPost(
|
||||
task_post_id="ba87780d",
|
||||
user_post_id="0e276b98",
|
||||
parent_post_id="cec432cf",
|
||||
DummyMessage(
|
||||
task_message_id="ba87780d",
|
||||
user_message_id="0e276b98",
|
||||
parent_message_id="cec432cf",
|
||||
text="I'm unsure how to interpret this. Is it a riddle?",
|
||||
role="assistant",
|
||||
),
|
||||
]
|
||||
|
||||
for p in dummy_posts:
|
||||
wp = pr.fetch_workpackage_by_postid(p.task_post_id)
|
||||
if wp and not wp.ack:
|
||||
logger.warning("Deleting unacknowledged seed data work package")
|
||||
db.delete(wp)
|
||||
wp = None
|
||||
if not wp:
|
||||
if p.parent_post_id is None:
|
||||
wp = pr.store_task(
|
||||
protocol_schema.InitialPromptTask(hint=""), thread_id=None, parent_post_id=None
|
||||
for msg in dummy_messages:
|
||||
task = pr.fetch_task_by_frontend_message_id(msg.task_message_id)
|
||||
if task and not task.ack:
|
||||
logger.warning("Deleting unacknowledged seed data task")
|
||||
db.delete(task)
|
||||
task = None
|
||||
if not task:
|
||||
if msg.parent_message_id is None:
|
||||
task = pr.store_task(
|
||||
protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
|
||||
)
|
||||
else:
|
||||
print("p.parent_post_id", p.parent_post_id)
|
||||
parent_post = pr.fetch_post_by_frontend_post_id(p.parent_post_id, fail_if_missing=True)
|
||||
wp = pr.store_task(
|
||||
parent_message = pr.fetch_message_by_frontend_message_id(
|
||||
msg.parent_message_id, fail_if_missing=True
|
||||
)
|
||||
task = pr.store_task(
|
||||
protocol_schema.AssistantReplyTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[protocol_schema.ConversationMessage(text="dummy", is_assistant=False)]
|
||||
)
|
||||
),
|
||||
thread_id=parent_post.thread_id,
|
||||
parent_post_id=parent_post.id,
|
||||
message_tree_id=parent_message.message_tree_id,
|
||||
parent_message_id=parent_message.id,
|
||||
)
|
||||
pr.bind_frontend_post_id(wp.id, p.task_post_id)
|
||||
post = pr.store_text_reply(p.text, p.task_post_id, p.user_post_id)
|
||||
pr.bind_frontend_message_id(task.id, msg.task_message_id)
|
||||
message = pr.store_text_reply(msg.text, msg.task_message_id, msg.user_message_id)
|
||||
|
||||
logger.info(
|
||||
f"Inserted: post_id: {post.id}, payload: {post.payload.payload}, parent_post_id: {post.parent_id}"
|
||||
f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"seed data work_package found: {wp.id}")
|
||||
logger.debug(f"seed data task found: {task.id}")
|
||||
logger.info("Seed data check completed")
|
||||
|
||||
except Exception:
|
||||
|
||||
@@ -18,8 +18,8 @@ router = APIRouter()
|
||||
def generate_task(
|
||||
request: protocol_schema.TaskRequest, pr: PromptRepository
|
||||
) -> Tuple[protocol_schema.Task, Optional[UUID], Optional[UUID]]:
|
||||
thread_id = None
|
||||
parent_post_id = None
|
||||
message_tree_id = None
|
||||
parent_message_id = None
|
||||
|
||||
match request.type:
|
||||
case protocol_schema.TaskRequestType.random:
|
||||
@@ -54,38 +54,42 @@ def generate_task(
|
||||
task = protocol_schema.InitialPromptTask(
|
||||
hint="Ask the assistant about a current event." # this is optional
|
||||
)
|
||||
case protocol_schema.TaskRequestType.user_reply:
|
||||
logger.info("Generating a UserReplyTask.")
|
||||
posts = pr.fetch_random_conversation("assistant")
|
||||
messages = [
|
||||
protocol_schema.ConversationMessage(text=p.payload.payload.text, is_assistant=(p.role == "assistant"))
|
||||
for p in posts
|
||||
case protocol_schema.TaskRequestType.prompter_reply:
|
||||
logger.info("Generating a PrompterReplyTask.")
|
||||
messages = pr.fetch_random_conversation("assistant")
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=msg.payload.payload.text, is_assistant=(msg.role == "assistant")
|
||||
)
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
task = protocol_schema.UserReplyTask(conversation=protocol_schema.Conversation(messages=messages))
|
||||
thread_id = posts[-1].thread_id
|
||||
parent_post_id = posts[-1].id
|
||||
task = protocol_schema.PrompterReplyTask(conversation=protocol_schema.Conversation(messages=task_messages))
|
||||
message_tree_id = messages[-1].message_tree_id
|
||||
parent_message_id = messages[-1].id
|
||||
case protocol_schema.TaskRequestType.assistant_reply:
|
||||
logger.info("Generating a AssistantReplyTask.")
|
||||
posts = pr.fetch_random_conversation("user")
|
||||
messages = [
|
||||
protocol_schema.ConversationMessage(text=p.payload.payload.text, is_assistant=(p.role == "assistant"))
|
||||
for p in posts
|
||||
messages = pr.fetch_random_conversation("prompter")
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=msg.payload.payload.text, is_assistant=(msg.role == "assistant")
|
||||
)
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
task = protocol_schema.AssistantReplyTask(conversation=protocol_schema.Conversation(messages=messages))
|
||||
thread_id = posts[-1].thread_id
|
||||
parent_post_id = posts[-1].id
|
||||
task = protocol_schema.AssistantReplyTask(conversation=protocol_schema.Conversation(messages=task_messages))
|
||||
message_tree_id = messages[-1].message_tree_id
|
||||
parent_message_id = messages[-1].id
|
||||
case protocol_schema.TaskRequestType.rank_initial_prompts:
|
||||
logger.info("Generating a RankInitialPromptsTask.")
|
||||
|
||||
posts = pr.fetch_random_initial_prompts()
|
||||
task = protocol_schema.RankInitialPromptsTask(prompts=[p.payload.payload.text for p in posts])
|
||||
case protocol_schema.TaskRequestType.rank_user_replies:
|
||||
logger.info("Generating a RankUserRepliesTask.")
|
||||
conversation, replies = pr.fetch_multiple_random_replies(post_role="assistant")
|
||||
messages = pr.fetch_random_initial_prompts()
|
||||
task = protocol_schema.RankInitialPromptsTask(prompts=[msg.payload.payload.text for msg in messages])
|
||||
case protocol_schema.TaskRequestType.rank_prompter_replies:
|
||||
logger.info("Generating a RankPrompterRepliesTask.")
|
||||
conversation, replies = pr.fetch_multiple_random_replies(message_role="assistant")
|
||||
|
||||
messages = [
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=p.payload.payload.text,
|
||||
is_assistant=(p.role == "assistant"),
|
||||
@@ -93,18 +97,18 @@ def generate_task(
|
||||
for p in conversation
|
||||
]
|
||||
replies = [p.payload.payload.text for p in replies]
|
||||
task = protocol_schema.RankUserRepliesTask(
|
||||
task = protocol_schema.RankPrompterRepliesTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=messages,
|
||||
messages=task_messages,
|
||||
),
|
||||
replies=replies,
|
||||
)
|
||||
|
||||
case protocol_schema.TaskRequestType.rank_assistant_replies:
|
||||
logger.info("Generating a RankAssistantRepliesTask.")
|
||||
conversation, replies = pr.fetch_multiple_random_replies(post_role="user")
|
||||
conversation, replies = pr.fetch_multiple_random_replies(message_role="prompter")
|
||||
|
||||
messages = [
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=p.payload.payload.text,
|
||||
is_assistant=(p.role == "assistant"),
|
||||
@@ -113,7 +117,7 @@ def generate_task(
|
||||
]
|
||||
replies = [p.payload.payload.text for p in replies]
|
||||
task = protocol_schema.RankAssistantRepliesTask(
|
||||
conversation=protocol_schema.Conversation(messages=messages),
|
||||
conversation=protocol_schema.Conversation(messages=task_messages),
|
||||
replies=replies,
|
||||
)
|
||||
case _:
|
||||
@@ -121,7 +125,7 @@ def generate_task(
|
||||
|
||||
logger.info(f"Generated {task=}.")
|
||||
|
||||
return task, thread_id, parent_post_id
|
||||
return task, message_tree_id, parent_message_id
|
||||
|
||||
|
||||
@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added
|
||||
@@ -138,8 +142,8 @@ def request_task(
|
||||
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, request.user)
|
||||
task, thread_id, parent_post_id = generate_task(request, pr)
|
||||
pr.store_task(task, thread_id, parent_post_id, request.collective)
|
||||
task, message_tree_id, parent_message_id = generate_task(request, pr)
|
||||
pr.store_task(task, message_tree_id, parent_message_id, request.collective)
|
||||
|
||||
except OasstError:
|
||||
raise
|
||||
@@ -150,7 +154,7 @@ def request_task(
|
||||
|
||||
|
||||
@router.post("/{task_id}/ack")
|
||||
def acknowledge_task(
|
||||
def tasks_acknowledge(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
@@ -166,9 +170,9 @@ def acknowledge_task(
|
||||
try:
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
|
||||
# here we store the post id in the database for the task
|
||||
# here we store the message id in the database for the task
|
||||
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
|
||||
pr.bind_frontend_post_id(task_id=task_id, post_id=ack_request.post_id)
|
||||
pr.bind_frontend_message_id(task_id=task_id, frontend_message_id=ack_request.message_id)
|
||||
|
||||
except OasstError:
|
||||
raise
|
||||
@@ -179,7 +183,7 @@ def acknowledge_task(
|
||||
|
||||
|
||||
@router.post("/{task_id}/nack")
|
||||
def acknowledge_task_failure(
|
||||
def tasks_acknowledge_failure(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
@@ -201,7 +205,7 @@ def acknowledge_task_failure(
|
||||
|
||||
|
||||
@router.post("/interaction")
|
||||
def post_interaction(
|
||||
def tasks_interaction(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
@@ -216,29 +220,31 @@ def post_interaction(
|
||||
pr = PromptRepository(db, api_client, user=interaction.user)
|
||||
|
||||
match type(interaction):
|
||||
case protocol_schema.TextReplyToPost:
|
||||
case protocol_schema.TextReplyToMessage:
|
||||
logger.info(
|
||||
f"Frontend reports text reply to {interaction.post_id=} with {interaction.text=} by {interaction.user=}."
|
||||
f"Frontend reports text reply to {interaction.message_id=} with {interaction.text=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
# here we store the text reply in the database
|
||||
pr.store_text_reply(
|
||||
text=interaction.text, post_id=interaction.post_id, user_post_id=interaction.user_post_id
|
||||
text=interaction.text,
|
||||
frontend_message_id=interaction.message_id,
|
||||
user_frontend_message_id=interaction.user_message_id,
|
||||
)
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
case protocol_schema.PostRating:
|
||||
case protocol_schema.MessageRating:
|
||||
logger.info(
|
||||
f"Frontend reports rating of {interaction.post_id=} with {interaction.rating=} by {interaction.user=}."
|
||||
f"Frontend reports rating of {interaction.message_id=} with {interaction.rating=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
# here we store the rating in the database
|
||||
pr.store_rating(interaction)
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
case protocol_schema.PostRanking:
|
||||
case protocol_schema.MessageRanking:
|
||||
logger.info(
|
||||
f"Frontend reports ranking of {interaction.post_id=} with {interaction.ranking=} by {interaction.user=}."
|
||||
f"Frontend reports ranking of {interaction.message_id=} with {interaction.ranking=} by {interaction.user=}."
|
||||
)
|
||||
|
||||
# TODO: check if the ranking is valid
|
||||
@@ -262,5 +268,5 @@ def close_collective_task(
|
||||
):
|
||||
api_client = deps.api_auth(api_key, db)
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
pr.close_task(close_task_request.post_id)
|
||||
pr.close_task(close_task_request.message_id)
|
||||
return protocol_schema.TaskDone()
|
||||
|
||||
@@ -27,21 +27,21 @@ class OasstErrorCode(IntEnum):
|
||||
TASK_GENERATION_FAILED = 1005
|
||||
|
||||
# 2000-3000: prompt_repository
|
||||
INVALID_POST_ID = 2000
|
||||
POST_NOT_FOUND = 2001
|
||||
INVALID_FRONTEND_MESSAGE_ID = 2000
|
||||
MESSAGE_NOT_FOUND = 2001
|
||||
RATING_OUT_OF_RANGE = 2002
|
||||
INVALID_RANKING_VALUE = 2003
|
||||
INVALID_TASK_TYPE = 2004
|
||||
USER_NOT_SPECIFIED = 2005
|
||||
NO_THREADS_FOUND = 2006
|
||||
NO_MESSAGE_TREE_FOUND = 2006
|
||||
NO_REPLIES_FOUND = 2007
|
||||
WORK_PACKAGE_NOT_FOUND = 2100
|
||||
WORK_PACKAGE_EXPIRED = 2101
|
||||
WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH = 2102
|
||||
WORK_PACKAGE_ALREADY_UPDATED = 2103
|
||||
WORK_PACKAGE_NOT_ACK = 2104
|
||||
WORK_PACKAGE_ALREADY_DONE = 2105
|
||||
WORK_PACKAGE_NOT_COLLECTIVE = 2106
|
||||
TASK_NOT_FOUND = 2100
|
||||
TASK_EXPIRED = 2101
|
||||
TASK_PAYLOAD_TYPE_MISMATCH = 2102
|
||||
TASK_ALREADY_UPDATED = 2103
|
||||
TASK_NOT_ACK = 2104
|
||||
TASK_ALREADY_DONE = 2105
|
||||
TASK_NOT_COLLECTIVE = 2106
|
||||
|
||||
|
||||
class OasstError(Exception):
|
||||
|
||||
@@ -3,7 +3,7 @@ import enum
|
||||
from typing import Literal, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.models import ApiClient, Journal, Person, WorkPackage
|
||||
from oasst_backend.models import ApiClient, Journal, Task, User
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer, payload_type
|
||||
from oasst_shared.utils import utcnow
|
||||
from pydantic import BaseModel
|
||||
@@ -14,71 +14,71 @@ class JournalEventType(str, enum.Enum):
|
||||
"""A label for a piece of text."""
|
||||
|
||||
user_created = "user_created"
|
||||
text_reply_to_post = "text_reply_to_post"
|
||||
post_rating = "post_rating"
|
||||
post_ranking = "post_ranking"
|
||||
text_reply_to_message = "text_reply_to_message"
|
||||
message_rating = "message_rating"
|
||||
message_ranking = "message_ranking"
|
||||
|
||||
|
||||
@payload_type
|
||||
class JournalEvent(BaseModel):
|
||||
type: str
|
||||
person_id: Optional[UUID]
|
||||
post_id: Optional[UUID]
|
||||
workpackage_id: Optional[UUID]
|
||||
user_id: Optional[UUID]
|
||||
message_id: Optional[UUID]
|
||||
task_id: Optional[UUID]
|
||||
task_type: Optional[str]
|
||||
|
||||
|
||||
@payload_type
|
||||
class TextReplyEvent(JournalEvent):
|
||||
type: Literal[JournalEventType.text_reply_to_post] = JournalEventType.text_reply_to_post
|
||||
type: Literal[JournalEventType.text_reply_to_message] = JournalEventType.text_reply_to_message
|
||||
length: int
|
||||
role: str
|
||||
|
||||
|
||||
@payload_type
|
||||
class RatingEvent(JournalEvent):
|
||||
type: Literal[JournalEventType.post_rating] = JournalEventType.post_rating
|
||||
type: Literal[JournalEventType.message_rating] = JournalEventType.message_rating
|
||||
rating: int
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankingEvent(JournalEvent):
|
||||
type: Literal[JournalEventType.post_ranking] = JournalEventType.post_ranking
|
||||
type: Literal[JournalEventType.message_ranking] = JournalEventType.message_ranking
|
||||
ranking: list[int]
|
||||
|
||||
|
||||
class JournalWriter:
|
||||
def __init__(self, db: Session, api_client: ApiClient, person: Person):
|
||||
def __init__(self, db: Session, api_client: ApiClient, user: User):
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
self.person = person
|
||||
self.person_id = self.person.id if self.person else None
|
||||
self.user = user
|
||||
self.user_id = self.user.id if self.user else None
|
||||
|
||||
def log_text_reply(self, work_package: WorkPackage, post_id: UUID, role: str, length: int) -> Journal:
|
||||
def log_text_reply(self, task: Task, message_id: UUID, role: str, length: int) -> Journal:
|
||||
return self.log(
|
||||
task_type=work_package.payload_type,
|
||||
event_type=JournalEventType.text_reply_to_post,
|
||||
task_type=task.payload_type,
|
||||
event_type=JournalEventType.text_reply_to_message,
|
||||
payload=TextReplyEvent(role=role, length=length),
|
||||
workpackage_id=work_package.id,
|
||||
post_id=post_id,
|
||||
task_id=task.id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
def log_rating(self, work_package: WorkPackage, post_id: UUID, rating: int) -> Journal:
|
||||
def log_rating(self, task: Task, message_id: UUID, rating: int) -> Journal:
|
||||
return self.log(
|
||||
task_type=work_package.payload_type,
|
||||
event_type=JournalEventType.post_rating,
|
||||
task_type=task.payload_type,
|
||||
event_type=JournalEventType.message_rating,
|
||||
payload=RatingEvent(rating=rating),
|
||||
workpackage_id=work_package.id,
|
||||
post_id=post_id,
|
||||
task_id=task.id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
def log_ranking(self, work_package: WorkPackage, post_id: UUID, ranking: list[int]) -> Journal:
|
||||
def log_ranking(self, task: Task, message_id: UUID, ranking: list[int]) -> Journal:
|
||||
return self.log(
|
||||
task_type=work_package.payload_type,
|
||||
event_type=JournalEventType.post_ranking,
|
||||
task_type=task.payload_type,
|
||||
event_type=JournalEventType.message_ranking,
|
||||
payload=RankingEvent(ranking=ranking),
|
||||
workpackage_id=work_package.id,
|
||||
post_id=post_id,
|
||||
task_id=task.id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
def log(
|
||||
@@ -87,8 +87,8 @@ class JournalWriter:
|
||||
payload: JournalEvent,
|
||||
task_type: str,
|
||||
event_type: str = None,
|
||||
workpackage_id: Optional[UUID] = None,
|
||||
post_id: Optional[UUID] = None,
|
||||
task_id: Optional[UUID] = None,
|
||||
message_id: Optional[UUID] = None,
|
||||
commit: bool = True,
|
||||
) -> Journal:
|
||||
if event_type is None:
|
||||
@@ -97,22 +97,22 @@ class JournalWriter:
|
||||
else:
|
||||
event_type = type(payload).__name__
|
||||
|
||||
if payload.person_id is None:
|
||||
payload.person_id = self.person_id
|
||||
if payload.post_id is None:
|
||||
payload.post_id = post_id
|
||||
if payload.workpackage_id is None:
|
||||
payload.workpackage_id = workpackage_id
|
||||
if payload.user_id is None:
|
||||
payload.user_id = self.user_id
|
||||
if payload.message_id is None:
|
||||
payload.message_id = message_id
|
||||
if payload.task_id is None:
|
||||
payload.task_id = task_id
|
||||
if payload.task_type is None:
|
||||
payload.task_type = task_type
|
||||
|
||||
entry = Journal(
|
||||
person_id=self.person_id,
|
||||
user_id=self.user_id,
|
||||
api_client_id=self.api_client.id,
|
||||
created_date=utcnow(),
|
||||
event_type=event_type,
|
||||
event_payload=PayloadContainer(payload=payload),
|
||||
post_id=post_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
self.db.add(entry)
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from .api_client import ApiClient
|
||||
from .journal import Journal, JournalIntegration
|
||||
from .person import Person
|
||||
from .person_stats import PersonStats
|
||||
from .post import Post
|
||||
from .post_reaction import PostReaction
|
||||
from .message import Message
|
||||
from .message_reaction import MessageReaction
|
||||
from .task import Task
|
||||
from .text_labels import TextLabels
|
||||
from .work_package import WorkPackage
|
||||
from .user import User
|
||||
from .user_stats import UserStats
|
||||
|
||||
__all__ = [
|
||||
"ApiClient",
|
||||
"Person",
|
||||
"PersonStats",
|
||||
"Post",
|
||||
"PostReaction",
|
||||
"WorkPackage",
|
||||
"User",
|
||||
"UserStats",
|
||||
"Message",
|
||||
"MessageReaction",
|
||||
"Task",
|
||||
"TextLabels",
|
||||
"Journal",
|
||||
"JournalIntegration",
|
||||
|
||||
@@ -32,8 +32,8 @@ class InitialPromptPayload(TaskPayload):
|
||||
|
||||
|
||||
@payload_type
|
||||
class UserReplyPayload(TaskPayload):
|
||||
type: Literal["user_reply"] = "user_reply"
|
||||
class PrompterReplyPayload(TaskPayload):
|
||||
type: Literal["prompter_reply"] = "prompter_reply"
|
||||
conversation: protocol_schema.Conversation
|
||||
hint: str | None
|
||||
|
||||
@@ -45,7 +45,7 @@ class AssistantReplyPayload(TaskPayload):
|
||||
|
||||
|
||||
@payload_type
|
||||
class PostPayload(BaseModel):
|
||||
class MessagePayload(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
@@ -56,13 +56,13 @@ class ReactionPayload(BaseModel):
|
||||
|
||||
@payload_type
|
||||
class RatingReactionPayload(ReactionPayload):
|
||||
type: Literal["post_rating"] = "post_rating"
|
||||
type: Literal["message_rating"] = "message_rating"
|
||||
rating: str
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankingReactionPayload(ReactionPayload):
|
||||
type: Literal["post_ranking"] = "post_ranking"
|
||||
type: Literal["message_ranking"] = "message_ranking"
|
||||
ranking: list[int]
|
||||
|
||||
|
||||
@@ -81,10 +81,10 @@ class RankInitialPromptsPayload(TaskPayload):
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankUserRepliesPayload(RankConversationRepliesPayload):
|
||||
"""A task to rank a set of user replies to a conversation."""
|
||||
class RankPrompterRepliesPayload(RankConversationRepliesPayload):
|
||||
"""A task to rank a set of prompter replies to a conversation."""
|
||||
|
||||
type: Literal["rank_user_replies"] = "rank_user_replies"
|
||||
type: Literal["rank_prompter_replies"] = "rank_prompter_replies"
|
||||
|
||||
|
||||
@payload_type
|
||||
|
||||
@@ -33,8 +33,8 @@ class Journal(SQLModel, table=True):
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True)
|
||||
post_id: Optional[UUID] = Field(foreign_key="post.id", nullable=True)
|
||||
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
message_id: Optional[UUID] = Field(foreign_key="message.id", nullable=True)
|
||||
api_client_id: UUID = Field(foreign_key="api_client.id")
|
||||
|
||||
event_type: str = Field(nullable=False, max_length=200)
|
||||
|
||||
@@ -10,9 +10,9 @@ from sqlmodel import Field, Index, SQLModel
|
||||
from .payload_column_type import PayloadContainer, payload_column_type
|
||||
|
||||
|
||||
class Post(SQLModel, table=True):
|
||||
__tablename__ = "post"
|
||||
__table_args__ = (Index("ix_post_frontend_post_id", "api_client_id", "frontend_post_id", unique=True),)
|
||||
class Message(SQLModel, table=True):
|
||||
__tablename__ = "message"
|
||||
__table_args__ = (Index("ix_message_frontend_message_id", "api_client_id", "frontend_message_id", unique=True),)
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
@@ -20,12 +20,12 @@ class Post(SQLModel, table=True):
|
||||
),
|
||||
)
|
||||
parent_id: UUID = Field(nullable=True)
|
||||
thread_id: UUID = Field(nullable=False, index=True)
|
||||
workpackage_id: UUID = Field(nullable=True, index=True)
|
||||
person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True)
|
||||
role: str = Field(nullable=False, max_length=128)
|
||||
message_tree_id: UUID = Field(nullable=False, index=True)
|
||||
task_id: UUID = Field(nullable=True, index=True)
|
||||
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
role: str = Field(nullable=False, max_length=128) # valid: "prompter" | "assistant"
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
frontend_post_id: str = Field(max_length=200, nullable=False)
|
||||
frontend_message_id: str = Field(max_length=200, nullable=False)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
+6
-6
@@ -10,14 +10,14 @@ from sqlmodel import Field, SQLModel
|
||||
from .payload_column_type import PayloadContainer, payload_column_type
|
||||
|
||||
|
||||
class PostReaction(SQLModel, table=True):
|
||||
__tablename__ = "post_reaction"
|
||||
class MessageReaction(SQLModel, table=True):
|
||||
__tablename__ = "message_reaction"
|
||||
|
||||
work_package_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("work_package.id"), nullable=False, primary_key=True)
|
||||
task_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("task.id"), nullable=False, primary_key=True)
|
||||
)
|
||||
person_id: UUID = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("person.id"), nullable=False, primary_key=True)
|
||||
user_id: UUID = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False, primary_key=True)
|
||||
)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
@@ -11,8 +11,8 @@ from sqlmodel import Field, SQLModel
|
||||
from .payload_column_type import PayloadContainer, payload_column_type
|
||||
|
||||
|
||||
class WorkPackage(SQLModel, table=True):
|
||||
__tablename__ = "work_package"
|
||||
class Task(SQLModel, table=True):
|
||||
__tablename__ = "task"
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
@@ -23,15 +23,15 @@ class WorkPackage(SQLModel, table=True):
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
)
|
||||
expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
|
||||
person_id: UUID = Field(nullable=True, foreign_key="person.id", index=True)
|
||||
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
ack: Optional[bool] = None
|
||||
done: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
|
||||
frontend_ref_post_id: Optional[str] = None
|
||||
thread_id: Optional[UUID] = None
|
||||
parent_post_id: Optional[UUID] = None
|
||||
frontend_message_id: Optional[str] = None
|
||||
message_tree_id: Optional[UUID] = None
|
||||
parent_message_id: Optional[UUID] = None
|
||||
collective: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
|
||||
|
||||
@property
|
||||
@@ -21,5 +21,7 @@ class TextLabels(SQLModel, table=True):
|
||||
)
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
text: str = Field(nullable=False, max_length=2**16)
|
||||
post_id: Optional[UUID] = Field(sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("post.id"), nullable=True))
|
||||
message_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("message.id"), nullable=True)
|
||||
)
|
||||
labels: dict[str, float] = Field(default={}, sa_column=sa.Column(pg.JSONB), nullable=False)
|
||||
|
||||
@@ -8,9 +8,9 @@ import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
|
||||
class Person(SQLModel, table=True):
|
||||
__tablename__ = "person"
|
||||
__table_args__ = (Index("ix_person_username", "api_client_id", "username", "auth_method", unique=True),)
|
||||
class User(SQLModel, table=True):
|
||||
__tablename__ = "user"
|
||||
__table_args__ = (Index("ix_user_username", "api_client_id", "username", "auth_method", unique=True),)
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
+8
-8
@@ -8,11 +8,11 @@ import sqlalchemy.dialects.postgresql as pg
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class PersonStats(SQLModel, table=True):
|
||||
__tablename__ = "person_stats"
|
||||
class UserStats(SQLModel, table=True):
|
||||
__tablename__ = "user_stats"
|
||||
|
||||
person_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("person.id"), primary_key=True)
|
||||
user_id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(pg.UUID(as_uuid=True), sa.ForeignKey("user.id"), primary_key=True)
|
||||
)
|
||||
leader_score: int = 0
|
||||
modified_date: Optional[datetime] = Field(
|
||||
@@ -20,9 +20,9 @@ class PersonStats(SQLModel, table=True):
|
||||
)
|
||||
|
||||
reactions: int = 0 # reactions sent by user
|
||||
posts: int = 0 # posts sent by user
|
||||
messages: int = 0 # messages sent by user
|
||||
upvotes: int = 0 # received upvotes (form other users)
|
||||
downvotes: int = 0 # received downvotes (from other users)
|
||||
work_reward: int = 0 # reward for workpackage completions
|
||||
compare_wins: int = 0 # num times user's post won compare tasks
|
||||
compare_losses: int = 0 # num times users's post lost compare tasks
|
||||
task_reward: int = 0 # reward for task completions
|
||||
compare_wins: int = 0 # num times user's message won compare tasks
|
||||
compare_losses: int = 0 # num times users's message lost compare tasks
|
||||
@@ -7,7 +7,7 @@ import oasst_backend.models.db_payload as db_payload
|
||||
from loguru import logger
|
||||
from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.journal_writer import JournalWriter
|
||||
from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage
|
||||
from oasst_backend.models import ApiClient, Message, MessageReaction, Task, TextLabels, User
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session, func
|
||||
@@ -17,247 +17,245 @@ class PromptRepository:
|
||||
def __init__(self, db: Session, api_client: ApiClient, user: Optional[protocol_schema.User]):
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
self.person = self.lookup_person(user)
|
||||
self.person_id = self.person.id if self.person else None
|
||||
self.journal = JournalWriter(db, api_client, self.person)
|
||||
self.user = self.lookup_user(user)
|
||||
self.user_id = self.user.id if self.user else None
|
||||
self.journal = JournalWriter(db, api_client, self.user)
|
||||
|
||||
def lookup_person(self, user: protocol_schema.User) -> Person:
|
||||
if not user:
|
||||
def lookup_user(self, client_user: protocol_schema.User) -> User:
|
||||
if not client_user:
|
||||
return None
|
||||
person: Person = (
|
||||
self.db.query(Person)
|
||||
user: User = (
|
||||
self.db.query(User)
|
||||
.filter(
|
||||
Person.api_client_id == self.api_client.id,
|
||||
Person.username == user.id,
|
||||
Person.auth_method == user.auth_method,
|
||||
User.api_client_id == self.api_client.id,
|
||||
User.username == client_user.id,
|
||||
User.auth_method == client_user.auth_method,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if person is None:
|
||||
if user is None:
|
||||
# user is unknown, create new record
|
||||
person = Person(
|
||||
username=user.id,
|
||||
display_name=user.display_name,
|
||||
user = User(
|
||||
username=client_user.id,
|
||||
display_name=client_user.display_name,
|
||||
api_client_id=self.api_client.id,
|
||||
auth_method=user.auth_method,
|
||||
auth_method=client_user.auth_method,
|
||||
)
|
||||
self.db.add(person)
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
self.db.refresh(person)
|
||||
elif user.display_name and user.display_name != person.display_name:
|
||||
self.db.refresh(user)
|
||||
elif client_user.display_name and client_user.display_name != user.display_name:
|
||||
# we found the user but the display name changed
|
||||
person.display_name = user.display_name
|
||||
self.db.add(person)
|
||||
user.display_name = client_user.display_name
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
return person
|
||||
return user
|
||||
|
||||
def validate_post_id(self, post_id: str) -> None:
|
||||
if not isinstance(post_id, str):
|
||||
raise OasstError(f"post_id must be string, not {type(post_id)}", OasstErrorCode.INVALID_POST_ID)
|
||||
if not post_id:
|
||||
raise OasstError("post_id must not be empty", OasstErrorCode.INVALID_POST_ID)
|
||||
def validate_frontend_message_id(self, message_id: str) -> None:
|
||||
if not isinstance(message_id, str):
|
||||
raise OasstError(
|
||||
f"message_id must be string, not {type(message_id)}", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID
|
||||
)
|
||||
if not message_id:
|
||||
raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)
|
||||
|
||||
def bind_frontend_post_id(self, task_id: UUID, post_id: str):
|
||||
self.validate_post_id(post_id)
|
||||
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
|
||||
# find work package
|
||||
work_pack: WorkPackage = (
|
||||
self.db.query(WorkPackage)
|
||||
.filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id)
|
||||
.first()
|
||||
)
|
||||
if work_pack is None:
|
||||
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
|
||||
if work_pack.expired:
|
||||
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
|
||||
if work_pack.done or work_pack.ack is not None:
|
||||
raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED)
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if task.done or task.ack is not None:
|
||||
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
|
||||
|
||||
work_pack.frontend_ref_post_id = post_id
|
||||
work_pack.ack = True
|
||||
task.frontend_message_id = frontend_message_id
|
||||
task.ack = True
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(work_pack)
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
def acknowledge_task_failure(self, task_id):
|
||||
# find work package
|
||||
work_pack: WorkPackage = (
|
||||
self.db.query(WorkPackage)
|
||||
.filter(WorkPackage.id == task_id, WorkPackage.api_client_id == self.api_client.id)
|
||||
.first()
|
||||
)
|
||||
if work_pack is None:
|
||||
raise OasstError(f"WorkPackage for task {task_id} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
|
||||
if work_pack.expired:
|
||||
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
|
||||
if work_pack.done or work_pack.ack is not None:
|
||||
raise OasstError("WorkPackage already updated.", OasstErrorCode.WORK_PACKAGE_ALREADY_UPDATED)
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if task.done or task.ack is not None:
|
||||
raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)
|
||||
|
||||
work_pack.ack = False
|
||||
task.ack = False
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(work_pack)
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
def fetch_post_by_frontend_post_id(self, frontend_post_id: str, fail_if_missing: bool = True) -> Post:
|
||||
self.validate_post_id(frontend_post_id)
|
||||
post: Post = (
|
||||
self.db.query(Post)
|
||||
.filter(Post.api_client_id == self.api_client.id, Post.frontend_post_id == frontend_post_id)
|
||||
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
message: Message = (
|
||||
self.db.query(Message)
|
||||
.filter(Message.api_client_id == self.api_client.id, Message.frontend_message_id == frontend_message_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if fail_if_missing and post is None:
|
||||
raise OasstError(f"Post with post_id {frontend_post_id} not found.", OasstErrorCode.POST_NOT_FOUND)
|
||||
return post
|
||||
if fail_if_missing and message is None:
|
||||
raise OasstError(
|
||||
f"Message with frontend_message_id {frontend_message_id} not found.", OasstErrorCode.MESSAGE_NOT_FOUND
|
||||
)
|
||||
return message
|
||||
|
||||
def fetch_workpackage_by_postid(self, post_id: str) -> WorkPackage:
|
||||
self.validate_post_id(post_id)
|
||||
work_pack = (
|
||||
self.db.query(WorkPackage)
|
||||
.filter(WorkPackage.api_client_id == self.api_client.id, WorkPackage.frontend_ref_post_id == post_id)
|
||||
def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
|
||||
self.validate_frontend_message_id(message_id)
|
||||
task = (
|
||||
self.db.query(Task)
|
||||
.filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id)
|
||||
.one_or_none()
|
||||
)
|
||||
return work_pack
|
||||
return task
|
||||
|
||||
def store_text_reply(self, text: str, post_id: str, user_post_id: str, role: str = None) -> Post:
|
||||
self.validate_post_id(post_id)
|
||||
self.validate_post_id(user_post_id)
|
||||
def store_text_reply(
|
||||
self, text: str, frontend_message_id: str, user_frontend_message_id: str, role: str = None
|
||||
) -> Message:
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
self.validate_frontend_message_id(user_frontend_message_id)
|
||||
|
||||
wp = self.fetch_workpackage_by_postid(post_id)
|
||||
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
|
||||
|
||||
if wp is None:
|
||||
raise OasstError(f"WorkPackage for {post_id=} not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
|
||||
if wp.expired:
|
||||
raise OasstError("WorkPackage already expired.", OasstErrorCode.WORK_PACKAGE_EXPIRED)
|
||||
if not wp.ack:
|
||||
raise OasstError("WorkPackage is not acknowledged.", OasstErrorCode.WORK_PACKAGE_NOT_ACK)
|
||||
if wp.done:
|
||||
raise OasstError("WorkPackage already done.", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE)
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if not task.ack:
|
||||
raise OasstError("Task is not acknowledged.", OasstErrorCode.TASK_NOT_ACK)
|
||||
if task.done:
|
||||
raise OasstError("Task already done.", OasstErrorCode.TASK_ALREADY_DONE)
|
||||
|
||||
# If there's no parent post assume user started new conversation
|
||||
role = "user"
|
||||
# If there's no parent message assume user started new conversation
|
||||
role = "prompter"
|
||||
depth = 0
|
||||
|
||||
if wp.parent_post_id:
|
||||
parent_post = self.fetch_post(wp.parent_post_id)
|
||||
parent_post.children_count += 1
|
||||
self.db.add(parent_post)
|
||||
if task.parent_message_id:
|
||||
parent_message = self.fetch_message(task.parent_message_id)
|
||||
parent_message.children_count += 1
|
||||
self.db.add(parent_message)
|
||||
|
||||
depth = parent_post.depth + 1
|
||||
if parent_post.role == "assistant":
|
||||
role = "user"
|
||||
depth = parent_message.depth + 1
|
||||
if parent_message.role == "assistant":
|
||||
role = "prompter"
|
||||
else:
|
||||
role = "assistant"
|
||||
|
||||
# create reply post
|
||||
new_post_id = uuid4()
|
||||
user_post = self.insert_post(
|
||||
post_id=new_post_id,
|
||||
frontend_post_id=user_post_id,
|
||||
parent_id=wp.parent_post_id,
|
||||
thread_id=wp.thread_id or new_post_id,
|
||||
workpackage_id=wp.id,
|
||||
# create reply message
|
||||
new_message_id = uuid4()
|
||||
user_message = self.insert_message(
|
||||
message_id=new_message_id,
|
||||
frontend_message_id=user_frontend_message_id,
|
||||
parent_id=task.parent_message_id,
|
||||
message_tree_id=task.message_tree_id or new_message_id,
|
||||
task_id=task.id,
|
||||
role=role,
|
||||
payload=db_payload.PostPayload(text=text),
|
||||
payload=db_payload.MessagePayload(text=text),
|
||||
depth=depth,
|
||||
)
|
||||
if not wp.collective:
|
||||
wp.done = True
|
||||
self.db.add(wp)
|
||||
if not task.collective:
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
self.journal.log_text_reply(work_package=wp, post_id=new_post_id, role=role, length=len(text))
|
||||
return user_post
|
||||
self.journal.log_text_reply(task=task, message_id=new_message_id, role=role, length=len(text))
|
||||
return user_message
|
||||
|
||||
def store_rating(self, rating: protocol_schema.PostRating) -> PostReaction:
|
||||
post = self.fetch_post_by_frontend_post_id(rating.post_id, fail_if_missing=True)
|
||||
def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction:
|
||||
message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True)
|
||||
|
||||
work_package = self.fetch_workpackage_by_postid(rating.post_id)
|
||||
work_payload: db_payload.RateSummaryPayload = work_package.payload.payload
|
||||
if type(work_payload) != db_payload.RateSummaryPayload:
|
||||
task = self.fetch_task_by_frontend_message_id(rating.message_id)
|
||||
task_payload: db_payload.RateSummaryPayload = task.payload.payload
|
||||
if type(task_payload) != db_payload.RateSummaryPayload:
|
||||
raise OasstError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RateSummaryPayload}",
|
||||
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
|
||||
f"Task payload type mismatch: {type(task_payload)=} != {db_payload.RateSummaryPayload}",
|
||||
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
|
||||
)
|
||||
|
||||
if rating.rating < work_payload.scale.min or rating.rating > work_payload.scale.max:
|
||||
if rating.rating < task_payload.scale.min or rating.rating > task_payload.scale.max:
|
||||
raise OasstError(
|
||||
f"Invalid rating value: {rating.rating=} not in {work_payload.scale=}",
|
||||
f"Invalid rating value: {rating.rating=} not in {task_payload.scale=}",
|
||||
OasstErrorCode.RATING_OUT_OF_RANGE,
|
||||
)
|
||||
|
||||
# store reaction to post
|
||||
# store reaction to message
|
||||
reaction_payload = db_payload.RatingReactionPayload(rating=rating.rating)
|
||||
reaction = self.insert_reaction(post.id, reaction_payload)
|
||||
if not work_package.collective:
|
||||
work_package.done = True
|
||||
self.db.add(work_package)
|
||||
reaction = self.insert_reaction(message.id, reaction_payload)
|
||||
if not task.collective:
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
|
||||
self.journal.log_rating(work_package, post_id=post.id, rating=rating.rating)
|
||||
logger.info(f"Ranking {rating.rating} stored for work_package {work_package.id}.")
|
||||
self.journal.log_rating(task, message_id=message.id, rating=rating.rating)
|
||||
logger.info(f"Ranking {rating.rating} stored for task {task.id}.")
|
||||
return reaction
|
||||
|
||||
def store_ranking(self, ranking: protocol_schema.PostRanking) -> PostReaction:
|
||||
# fetch work_package
|
||||
work_package = self.fetch_workpackage_by_postid(ranking.post_id)
|
||||
if not work_package.collective:
|
||||
work_package.done = True
|
||||
self.db.add(work_package)
|
||||
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> MessageReaction:
|
||||
# fetch task
|
||||
task = self.fetch_task_by_frontend_message_id(ranking.message_id)
|
||||
if not task.collective:
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
|
||||
work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
|
||||
work_package.payload.payload
|
||||
task_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
|
||||
task.payload.payload
|
||||
)
|
||||
|
||||
match type(work_payload):
|
||||
match type(task_payload):
|
||||
|
||||
case db_payload.RankUserRepliesPayload | db_payload.RankAssistantRepliesPayload:
|
||||
case db_payload.RankPrompterRepliesPayload | db_payload.RankAssistantRepliesPayload:
|
||||
# validate ranking
|
||||
num_replies = len(work_payload.replies)
|
||||
num_replies = len(task_payload.replies)
|
||||
if sorted(ranking.ranking) != list(range(num_replies)):
|
||||
raise OasstError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_replies=}).",
|
||||
OasstErrorCode.INVALID_RANKING_VALUE,
|
||||
)
|
||||
|
||||
# store reaction to post
|
||||
# store reaction to message
|
||||
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(work_package.id, reaction_payload)
|
||||
# TODO: resolve post_id
|
||||
self.journal.log_ranking(work_package, post_id=None, ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(task.id, reaction_payload)
|
||||
# TODO: resolve message_id
|
||||
self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
|
||||
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
|
||||
|
||||
return reaction
|
||||
|
||||
case db_payload.RankInitialPromptsPayload:
|
||||
# validate ranking
|
||||
if sorted(ranking.ranking) != list(range(num_prompts := len(work_payload.prompts))):
|
||||
if sorted(ranking.ranking) != list(range(num_prompts := len(task_payload.prompts))):
|
||||
raise OasstError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=}).",
|
||||
OasstErrorCode.INVALID_RANKING_VALUE,
|
||||
)
|
||||
|
||||
# store reaction to post
|
||||
# store reaction to message
|
||||
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(work_package.id, reaction_payload)
|
||||
# TODO: resolve post_id
|
||||
self.journal.log_ranking(work_package, post_id=None, ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(task.id, reaction_payload)
|
||||
# TODO: resolve message_id
|
||||
self.journal.log_ranking(task, message_id=None, ranking=ranking.ranking)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
|
||||
logger.info(f"Ranking {ranking.ranking} stored for task {task.id}.")
|
||||
|
||||
return reaction
|
||||
|
||||
case _:
|
||||
raise OasstError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}",
|
||||
OasstErrorCode.WORK_PACKAGE_PAYLOAD_TYPE_MISMATCH,
|
||||
f"task payload type mismatch: {type(task_payload)=} != {db_payload.RankConversationRepliesPayload}",
|
||||
OasstErrorCode.TASK_PAYLOAD_TYPE_MISMATCH,
|
||||
)
|
||||
|
||||
def store_task(
|
||||
self,
|
||||
task: protocol_schema.Task,
|
||||
thread_id: UUID = None,
|
||||
parent_post_id: UUID = None,
|
||||
message_tree_id: UUID = None,
|
||||
parent_message_id: UUID = None,
|
||||
collective: bool = False,
|
||||
) -> WorkPackage:
|
||||
) -> Task:
|
||||
payload: db_payload.TaskPayload
|
||||
match type(task):
|
||||
case protocol_schema.SummarizeStoryTask:
|
||||
@@ -271,8 +269,8 @@ class PromptRepository:
|
||||
case protocol_schema.InitialPromptTask:
|
||||
payload = db_payload.InitialPromptPayload(hint=task.hint)
|
||||
|
||||
case protocol_schema.UserReplyTask:
|
||||
payload = db_payload.UserReplyPayload(conversation=task.conversation, hint=task.hint)
|
||||
case protocol_schema.PrompterReplyTask:
|
||||
payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)
|
||||
|
||||
case protocol_schema.AssistantReplyTask:
|
||||
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
|
||||
@@ -280,8 +278,8 @@ class PromptRepository:
|
||||
case protocol_schema.RankInitialPromptsTask:
|
||||
payload = db_payload.RankInitialPromptsPayload(tpye=task.type, prompts=task.prompts)
|
||||
|
||||
case protocol_schema.RankUserRepliesTask:
|
||||
payload = db_payload.RankUserRepliesPayload(
|
||||
case protocol_schema.RankPrompterRepliesTask:
|
||||
payload = db_payload.RankPrompterRepliesPayload(
|
||||
tpye=task.type, conversation=task.conversation, replies=task.replies
|
||||
)
|
||||
|
||||
@@ -293,81 +291,85 @@ class PromptRepository:
|
||||
case _:
|
||||
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)
|
||||
|
||||
wp = self.insert_work_package(
|
||||
payload=payload, id=task.id, thread_id=thread_id, parent_post_id=parent_post_id, collective=collective
|
||||
task = self.insert_task(
|
||||
payload=payload,
|
||||
id=task.id,
|
||||
message_tree_id=message_tree_id,
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
)
|
||||
assert wp.id == task.id
|
||||
return wp
|
||||
assert task.id == task.id
|
||||
return task
|
||||
|
||||
def insert_work_package(
|
||||
def insert_task(
|
||||
self,
|
||||
payload: db_payload.TaskPayload,
|
||||
id: UUID = None,
|
||||
thread_id: UUID = None,
|
||||
parent_post_id: UUID = None,
|
||||
message_tree_id: UUID = None,
|
||||
parent_message_id: UUID = None,
|
||||
collective: bool = False,
|
||||
) -> WorkPackage:
|
||||
) -> Task:
|
||||
c = PayloadContainer(payload=payload)
|
||||
wp = WorkPackage(
|
||||
task = Task(
|
||||
id=id,
|
||||
person_id=self.person_id,
|
||||
user_id=self.user_id,
|
||||
payload_type=type(payload).__name__,
|
||||
payload=c,
|
||||
api_client_id=self.api_client.id,
|
||||
thread_id=thread_id,
|
||||
parent_post_id=parent_post_id,
|
||||
message_tree_id=message_tree_id,
|
||||
parent_message_id=parent_message_id,
|
||||
collective=collective,
|
||||
)
|
||||
self.db.add(wp)
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
self.db.refresh(wp)
|
||||
return wp
|
||||
self.db.refresh(task)
|
||||
return task
|
||||
|
||||
def insert_post(
|
||||
def insert_message(
|
||||
self,
|
||||
*,
|
||||
post_id: UUID,
|
||||
frontend_post_id: str,
|
||||
message_id: UUID,
|
||||
frontend_message_id: str,
|
||||
parent_id: UUID,
|
||||
thread_id: UUID,
|
||||
workpackage_id: UUID,
|
||||
message_tree_id: UUID,
|
||||
task_id: UUID,
|
||||
role: str,
|
||||
payload: db_payload.PostPayload,
|
||||
payload: db_payload.MessagePayload,
|
||||
payload_type: str = None,
|
||||
depth: int = 0,
|
||||
) -> Post:
|
||||
) -> Message:
|
||||
if payload_type is None:
|
||||
if payload is None:
|
||||
payload_type = "null"
|
||||
else:
|
||||
payload_type = type(payload).__name__
|
||||
|
||||
post = Post(
|
||||
id=post_id,
|
||||
message = Message(
|
||||
id=message_id,
|
||||
parent_id=parent_id,
|
||||
thread_id=thread_id,
|
||||
workpackage_id=workpackage_id,
|
||||
person_id=self.person_id,
|
||||
message_tree_id=message_tree_id,
|
||||
task_id=task_id,
|
||||
user_id=self.user_id,
|
||||
role=role,
|
||||
frontend_post_id=frontend_post_id,
|
||||
frontend_message_id=frontend_message_id,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=payload_type,
|
||||
payload=PayloadContainer(payload=payload),
|
||||
depth=depth,
|
||||
)
|
||||
self.db.add(post)
|
||||
self.db.add(message)
|
||||
self.db.commit()
|
||||
self.db.refresh(post)
|
||||
return post
|
||||
self.db.refresh(message)
|
||||
return message
|
||||
|
||||
def insert_reaction(self, work_package_id: UUID, payload: db_payload.ReactionPayload) -> PostReaction:
|
||||
if self.person_id is None:
|
||||
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
|
||||
if self.user_id is None:
|
||||
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
|
||||
|
||||
container = PayloadContainer(payload=payload)
|
||||
reaction = PostReaction(
|
||||
work_package_id=work_package_id,
|
||||
person_id=self.person_id,
|
||||
reaction = MessageReaction(
|
||||
task_id=task_id,
|
||||
user_id=self.user_id,
|
||||
payload=container,
|
||||
api_client_id=self.api_client.id,
|
||||
payload_type=type(payload).__name__,
|
||||
@@ -383,108 +385,110 @@ class PromptRepository:
|
||||
text=text_labels.text,
|
||||
labels=text_labels.labels,
|
||||
)
|
||||
if text_labels.has_post_id:
|
||||
self.fetch_post_by_frontend_post_id(text_labels.post_id, fail_if_missing=True)
|
||||
model.post_id = text_labels.post_id
|
||||
if text_labels.has_message_id:
|
||||
self.fetch_message_by_frontend_message_id(text_labels.message_id, fail_if_missing=True)
|
||||
model.message_id = text_labels.message_id
|
||||
self.db.add(model)
|
||||
self.db.commit()
|
||||
self.db.refresh(model)
|
||||
return model
|
||||
|
||||
def fetch_random_thread(self, require_role: str = None) -> list[Post]:
|
||||
def fetch_random_message_tree(self, require_role: str = None) -> list[Message]:
|
||||
"""
|
||||
Loads all posts of a random thread.
|
||||
Loads all messages of a random message_tree.
|
||||
|
||||
:param require_role: If set loads only thread which has
|
||||
at least one post with given role.
|
||||
:param require_role: If set loads only message_tree which has
|
||||
at least one message with given role.
|
||||
"""
|
||||
distinct_threads = self.db.query(Post.thread_id).distinct(Post.thread_id)
|
||||
distinct_message_trees = self.db.query(Message.message_tree_id).distinct(Message.message_tree_id)
|
||||
if require_role:
|
||||
distinct_threads = distinct_threads.filter(Post.role == require_role)
|
||||
distinct_threads = distinct_threads.subquery()
|
||||
distinct_message_trees = distinct_message_trees.filter(Message.role == require_role)
|
||||
distinct_message_trees = distinct_message_trees.subquery()
|
||||
|
||||
random_thread = self.db.query(distinct_threads).order_by(func.random()).limit(1)
|
||||
thread_posts = self.db.query(Post).filter(Post.thread_id.in_(random_thread)).all()
|
||||
return thread_posts
|
||||
random_message_tree = self.db.query(distinct_message_trees).order_by(func.random()).limit(1)
|
||||
message_tree_messages = self.db.query(Message).filter(Message.message_tree_id.in_(random_message_tree)).all()
|
||||
return message_tree_messages
|
||||
|
||||
def fetch_random_conversation(self, last_post_role: str = None) -> list[Post]:
|
||||
def fetch_random_conversation(self, last_message_role: str = None) -> list[Message]:
|
||||
"""
|
||||
Picks a random linear conversation starting from any root post
|
||||
and ending somewhere in the thread, possibly at the root itself.
|
||||
Picks a random linear conversation starting from any root message
|
||||
and ending somewhere in the message_tree, possibly at the root itself.
|
||||
|
||||
:param last_post_role: If set will form a conversation ending with a post
|
||||
:param last_message_role: If set will form a conversation ending with a message
|
||||
created by this role. Necessary for the tasks like "user_reply" where
|
||||
the user should reply as a human and hence the last message of the conversation
|
||||
needs to have "assistant" role.
|
||||
"""
|
||||
thread_posts = self.fetch_random_thread(last_post_role)
|
||||
if not thread_posts:
|
||||
raise OasstError("No threads found", OasstErrorCode.NO_THREADS_FOUND)
|
||||
if last_post_role:
|
||||
conv_posts = [p for p in thread_posts if p.role == last_post_role]
|
||||
conv_posts = [random.choice(conv_posts)]
|
||||
messages_tree = self.fetch_random_message_tree(last_message_role)
|
||||
if not messages_tree:
|
||||
raise OasstError("No message tree found", OasstErrorCode.NO_MESSAGE_TREE_FOUND)
|
||||
if last_message_role:
|
||||
conv_messages = [m for m in messages_tree if m.role == last_message_role]
|
||||
conv_messages = [random.choice(conv_messages)]
|
||||
else:
|
||||
conv_posts = [random.choice(thread_posts)]
|
||||
thread_posts = {p.id: p for p in thread_posts}
|
||||
conv_messages = [random.choice(messages_tree)]
|
||||
messages_tree = {m.id: m for m in messages_tree}
|
||||
|
||||
while True:
|
||||
if not conv_posts[-1].parent_id:
|
||||
if not conv_messages[-1].parent_id:
|
||||
# reached the start of the conversation
|
||||
break
|
||||
|
||||
parent_post = thread_posts[conv_posts[-1].parent_id]
|
||||
conv_posts.append(parent_post)
|
||||
parent_message = messages_tree[conv_messages[-1].parent_id]
|
||||
conv_messages.append(parent_message)
|
||||
|
||||
return list(reversed(conv_posts))
|
||||
return list(reversed(conv_messages))
|
||||
|
||||
def fetch_random_initial_prompts(self, size: int = 5):
|
||||
posts = self.db.query(Post).filter(Post.parent_id.is_(None)).order_by(func.random()).limit(size).all()
|
||||
return posts
|
||||
messages = self.db.query(Message).filter(Message.parent_id.is_(None)).order_by(func.random()).limit(size).all()
|
||||
return messages
|
||||
|
||||
def fetch_thread(self, thread_id: UUID):
|
||||
return self.db.query(Post).filter(Post.thread_id == thread_id).all()
|
||||
def fetch_message_tree(self, message_tree_id: UUID):
|
||||
return self.db.query(Message).filter(Message.message_tree_id == message_tree_id).all()
|
||||
|
||||
def fetch_multiple_random_replies(self, max_size: int = 5, post_role: str = None):
|
||||
parent = self.db.query(Post.id).filter(Post.children_count > 1)
|
||||
if post_role:
|
||||
parent = parent.filter(Post.role == post_role)
|
||||
def fetch_multiple_random_replies(self, max_size: int = 5, message_role: str = None):
|
||||
parent = self.db.query(Message.id).filter(Message.children_count > 1)
|
||||
if message_role:
|
||||
parent = parent.filter(Message.role == message_role)
|
||||
|
||||
parent = parent.order_by(func.random()).limit(1)
|
||||
replies = self.db.query(Post).filter(Post.parent_id.in_(parent)).order_by(func.random()).limit(max_size).all()
|
||||
replies = (
|
||||
self.db.query(Message).filter(Message.parent_id.in_(parent)).order_by(func.random()).limit(max_size).all()
|
||||
)
|
||||
if not replies:
|
||||
raise OasstError("No replies found", OasstErrorCode.NO_REPLIES_FOUND)
|
||||
|
||||
thread = self.fetch_thread(replies[0].thread_id)
|
||||
thread = {p.id: p for p in thread}
|
||||
thread_posts = [thread[replies[0].parent_id]]
|
||||
message_tree = self.fetch_message_tree(replies[0].message_tree_id)
|
||||
message_tree = {p.id: p for p in message_tree}
|
||||
conversation = [message_tree[replies[0].parent_id]]
|
||||
while True:
|
||||
if not thread_posts[-1].parent_id:
|
||||
if not conversation[-1].parent_id:
|
||||
# reached start of the conversation
|
||||
break
|
||||
|
||||
parent_post = thread[thread_posts[-1].parent_id]
|
||||
thread_posts.append(parent_post)
|
||||
parent_message = message_tree[conversation[-1].parent_id]
|
||||
conversation.append(parent_message)
|
||||
|
||||
thread_posts = reversed(thread_posts)
|
||||
conversation = reversed(conversation)
|
||||
|
||||
return thread_posts, replies
|
||||
return conversation, replies
|
||||
|
||||
def fetch_post(self, post_id: UUID) -> Optional[Post]:
|
||||
return self.db.query(Post).filter(Post.id == post_id).one()
|
||||
def fetch_message(self, message_id: UUID) -> Optional[Message]:
|
||||
return self.db.query(Message).filter(Message.id == message_id).one()
|
||||
|
||||
def close_task(self, post_id: str, allow_personal_tasks: bool = False):
|
||||
self.validate_post_id(post_id)
|
||||
wp = self.fetch_workpackage_by_postid(post_id)
|
||||
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
|
||||
self.validate_frontend_message_id(frontend_message_id)
|
||||
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
|
||||
|
||||
if not wp:
|
||||
raise OasstError("Work package not found", OasstErrorCode.WORK_PACKAGE_NOT_FOUND)
|
||||
if wp.expired:
|
||||
raise OasstError("Work package expired", OasstErrorCode.WORK_PACKAGE_EXPIRED)
|
||||
if not allow_personal_tasks and not wp.collective:
|
||||
raise OasstError("This is not a collective task", OasstErrorCode.WORK_PACKAGE_NOT_COLLECTIVE)
|
||||
if wp.done:
|
||||
raise OasstError("Allready closed", OasstErrorCode.WORK_PACKAGE_ALREADY_DONE)
|
||||
if not task:
|
||||
raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
|
||||
if not allow_personal_tasks and not task.collective:
|
||||
raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
|
||||
if task.done:
|
||||
raise OasstError("Allready closed", OasstErrorCode.TASK_ALREADY_DONE)
|
||||
|
||||
wp.done = True
|
||||
self.db.add(wp)
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
@@ -10,16 +10,17 @@ from loguru import logger
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
|
||||
|
||||
# TODO: Move to `protocol`?
|
||||
class TaskType(str, enum.Enum):
|
||||
"""Task types."""
|
||||
|
||||
summarize_story = "summarize_story"
|
||||
rate_summary = "rate_summary"
|
||||
initial_prompt = "initial_prompt"
|
||||
user_reply = "user_reply"
|
||||
prompter_reply = "prompter_reply"
|
||||
assistant_reply = "assistant_reply"
|
||||
rank_initial_prompts = "rank_initial_prompts"
|
||||
rank_user_replies = "rank_user_replies"
|
||||
rank_prompter_replies = "rank_prompter_replies"
|
||||
rank_assistant_replies = "rank_assistant_replies"
|
||||
done = "task_done"
|
||||
|
||||
@@ -44,10 +45,10 @@ class OasstApiClient:
|
||||
TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
|
||||
TaskType.rate_summary: protocol_schema.RateSummaryTask,
|
||||
TaskType.initial_prompt: protocol_schema.InitialPromptTask,
|
||||
TaskType.user_reply: protocol_schema.UserReplyTask,
|
||||
TaskType.prompter_reply: protocol_schema.PrompterReplyTask,
|
||||
TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
|
||||
TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
|
||||
TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask,
|
||||
TaskType.rank_prompter_replies: protocol_schema.RankPrompterRepliesTask,
|
||||
TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
|
||||
TaskType.done: protocol_schema.TaskDone,
|
||||
}
|
||||
@@ -78,7 +79,7 @@ class OasstApiClient:
|
||||
logger.debug(f"Fetching task {task_type} for user {user}")
|
||||
req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective)
|
||||
resp = await self.post("/api/v1/tasks/", data=req.dict())
|
||||
logger.debug(f"Fetch task response: {resp}")
|
||||
logger.debug(f"RESP {resp}")
|
||||
return self._parse_task(resp)
|
||||
|
||||
async def fetch_random_task(
|
||||
@@ -88,10 +89,10 @@ class OasstApiClient:
|
||||
logger.debug(f"Fetching random for user {user}")
|
||||
return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective)
|
||||
|
||||
async def ack_task(self, task_id: str | UUID, post_id: str):
|
||||
async def ack_task(self, task_id: str | UUID, message_id: str):
|
||||
"""Send an ACK for a task to the backend."""
|
||||
logger.debug(f"ACK task {task_id} with post {post_id}")
|
||||
req = protocol_schema.TaskAck(post_id=post_id)
|
||||
logger.debug(f"ACK task {task_id} with post {message_id}")
|
||||
req = protocol_schema.TaskAck(message_id=message_id)
|
||||
return await self.post(f"/api/v1/tasks/{task_id}/ack", data=req.dict())
|
||||
|
||||
async def nack_task(self, task_id: str | UUID, reason: str):
|
||||
|
||||
@@ -23,6 +23,6 @@ class GuildSettings(BaseModel):
|
||||
await cursor.execute("SELECT * FROM guild_settings WHERE guild_id = ?", (guild_id,))
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
raise ValueError("No settings found for this guild.")
|
||||
return None
|
||||
|
||||
return cls.parse_obj(row)
|
||||
|
||||
@@ -84,9 +84,9 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType)
|
||||
logger.debug(f"Successful user input received: {event.content}")
|
||||
|
||||
# Send the response to the backend
|
||||
reply = protocol_schema.TextReplyToPost(
|
||||
post_id=str(msg_id),
|
||||
user_post_id=str(event.message_id),
|
||||
reply = protocol_schema.TextReplyToMessage(
|
||||
message_id=str(msg_id),
|
||||
user_message_id=str(event.message_id),
|
||||
user=protocol_schema.User(
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
@@ -206,20 +206,20 @@ async def _send_task(
|
||||
logger.debug("sending rank initial prompt task")
|
||||
embed = _rank_initial_prompt_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_user_replies:
|
||||
assert isinstance(task, protocol_schema.RankUserRepliesTask)
|
||||
elif task.type == TaskRequestType.rank_prompter_replies:
|
||||
assert isinstance(task, protocol_schema.RankPrompterRepliesTask)
|
||||
logger.debug("sending rank user reply task")
|
||||
embed = _rank_user_reply_embed(task)
|
||||
embed = _rank_prompter_reply_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_assistant_replies:
|
||||
assert isinstance(task, protocol_schema.RankAssistantRepliesTask)
|
||||
logger.debug("sending rank assistant reply task")
|
||||
embed = _rank_assistant_reply_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.user_reply:
|
||||
assert isinstance(task, protocol_schema.UserReplyTask)
|
||||
elif task.type == TaskRequestType.prompter_reply:
|
||||
assert isinstance(task, protocol_schema.PrompterReplyTask)
|
||||
logger.debug("sending user reply task")
|
||||
embed = _user_reply_embed(task)
|
||||
embed = _prompter_reply_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.assistant_reply:
|
||||
assert isinstance(task, protocol_schema.AssistantReplyTask)
|
||||
@@ -258,28 +258,29 @@ def _validate_user_input(content: str | None, task: protocol_schema.Task) -> boo
|
||||
# User message input
|
||||
if (
|
||||
task.type == TaskRequestType.initial_prompt
|
||||
or task.type == TaskRequestType.user_reply
|
||||
or task.type == TaskRequestType.prompter_reply
|
||||
or task.type == TaskRequestType.assistant_reply
|
||||
):
|
||||
assert isinstance(
|
||||
task, protocol_schema.InitialPromptTask | protocol_schema.UserReplyTask | protocol_schema.AssistantReplyTask
|
||||
task,
|
||||
protocol_schema.InitialPromptTask | protocol_schema.PrompterReplyTask | protocol_schema.AssistantReplyTask,
|
||||
)
|
||||
return len(content) > 0
|
||||
|
||||
# Ranking tasks
|
||||
elif task.type == TaskRequestType.rank_user_replies or task.type == TaskRequestType.rank_assistant_replies:
|
||||
assert isinstance(task, protocol_schema.RankUserRepliesTask | protocol_schema.RankAssistantRepliesTask)
|
||||
elif task.type == TaskRequestType.rank_prompter_replies or task.type == TaskRequestType.rank_assistant_replies:
|
||||
assert isinstance(task, protocol_schema.RankPrompterRepliesTask | protocol_schema.RankAssistantRepliesTask)
|
||||
num_replies = len(task.replies)
|
||||
|
||||
rankings = [int(r) for r in content.split(",")]
|
||||
return all([r in range(1, num_replies + 1) for r in rankings]) and len(rankings) == num_replies
|
||||
rankings = content.split(",")
|
||||
return set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies
|
||||
|
||||
elif task.type == TaskRequestType.rank_initial_prompts:
|
||||
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
||||
num_prompts = len(task.prompts)
|
||||
|
||||
rankings = [int(r) for r in content.split(",")]
|
||||
return all([r in range(1, num_prompts + 1) for r in rankings]) and len(rankings) == num_prompts
|
||||
rankings = content.split(",")
|
||||
return set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts
|
||||
|
||||
elif task.type == TaskRequestType.summarize_story:
|
||||
raise NotImplementedError
|
||||
@@ -369,7 +370,7 @@ def _rank_initial_prompt_embed(task: protocol_schema.RankInitialPromptsTask) ->
|
||||
return embed
|
||||
|
||||
|
||||
def _rank_user_reply_embed(task: protocol_schema.RankUserRepliesTask) -> hikari.Embed:
|
||||
def _rank_prompter_reply_embed(task: protocol_schema.RankPrompterRepliesTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="Rank User Reply",
|
||||
@@ -403,7 +404,7 @@ def _rank_assistant_reply_embed(task: protocol_schema.RankAssistantRepliesTask)
|
||||
return embed
|
||||
|
||||
|
||||
def _user_reply_embed(task: protocol_schema.UserReplyTask) -> hikari.Embed:
|
||||
def _prompter_reply_embed(task: protocol_schema.PrompterReplyTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="User Reply",
|
||||
|
||||
@@ -12,10 +12,10 @@ class TaskRequestType(str, enum.Enum):
|
||||
summarize_story = "summarize_story"
|
||||
rate_summary = "rate_summary"
|
||||
initial_prompt = "initial_prompt"
|
||||
user_reply = "user_reply"
|
||||
prompter_reply = "prompter_reply"
|
||||
assistant_reply = "assistant_reply"
|
||||
rank_initial_prompts = "rank_initial_prompts"
|
||||
rank_user_replies = "rank_user_replies"
|
||||
rank_prompter_replies = "rank_prompter_replies"
|
||||
rank_assistant_replies = "rank_assistant_replies"
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ class ConversationMessage(BaseModel):
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
"""Represents a conversation between the user and the assistant."""
|
||||
"""Represents a conversation between the prompter and the assistant."""
|
||||
|
||||
messages: list[ConversationMessage] = []
|
||||
|
||||
@@ -47,13 +47,13 @@ class TaskRequest(BaseModel):
|
||||
|
||||
|
||||
class TaskAck(BaseModel):
|
||||
"""The frontend acknowledges that it has received a task and created a post."""
|
||||
"""The frontend acknowledges that it has received a task and created a message."""
|
||||
|
||||
post_id: str
|
||||
message_id: str
|
||||
|
||||
|
||||
class TaskNAck(BaseModel):
|
||||
"""The frontend acknowledges that it has received a task but cannot create a post."""
|
||||
"""The frontend acknowledges that it has received a task but cannot create a message."""
|
||||
|
||||
reason: str
|
||||
|
||||
@@ -61,7 +61,7 @@ class TaskNAck(BaseModel):
|
||||
class TaskClose(BaseModel):
|
||||
"""The frontend asks to mark task as done"""
|
||||
|
||||
post_id: str
|
||||
message_id: str
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
@@ -114,10 +114,10 @@ class ReplyToConversationTask(Task):
|
||||
conversation: Conversation # the conversation so far
|
||||
|
||||
|
||||
class UserReplyTask(ReplyToConversationTask, WithHintMixin):
|
||||
class PrompterReplyTask(ReplyToConversationTask, WithHintMixin):
|
||||
"""A task to prompt the user to submit a reply to the assistant."""
|
||||
|
||||
type: Literal["user_reply"] = "user_reply"
|
||||
type: Literal["prompter_reply"] = "prompter_reply"
|
||||
|
||||
|
||||
class AssistantReplyTask(ReplyToConversationTask):
|
||||
@@ -141,10 +141,10 @@ class RankConversationRepliesTask(Task):
|
||||
replies: list[str]
|
||||
|
||||
|
||||
class RankUserRepliesTask(RankConversationRepliesTask):
|
||||
"""A task to rank a set of user replies to a conversation."""
|
||||
class RankPrompterRepliesTask(RankConversationRepliesTask):
|
||||
"""A task to rank a set of prompter replies to a conversation."""
|
||||
|
||||
type: Literal["rank_user_replies"] = "rank_user_replies"
|
||||
type: Literal["rank_prompter_replies"] = "rank_prompter_replies"
|
||||
|
||||
|
||||
class RankAssistantRepliesTask(RankConversationRepliesTask):
|
||||
@@ -165,11 +165,11 @@ AnyTask = Union[
|
||||
RateSummaryTask,
|
||||
InitialPromptTask,
|
||||
ReplyToConversationTask,
|
||||
UserReplyTask,
|
||||
PrompterReplyTask,
|
||||
AssistantReplyTask,
|
||||
RankInitialPromptsTask,
|
||||
RankConversationRepliesTask,
|
||||
RankUserRepliesTask,
|
||||
RankPrompterRepliesTask,
|
||||
RankAssistantRepliesTask,
|
||||
]
|
||||
|
||||
@@ -181,35 +181,35 @@ class Interaction(BaseModel):
|
||||
user: User
|
||||
|
||||
|
||||
class TextReplyToPost(Interaction):
|
||||
"""A user has replied to a post with text."""
|
||||
class TextReplyToMessage(Interaction):
|
||||
"""A user has replied to a message with text."""
|
||||
|
||||
type: Literal["text_reply_to_post"] = "text_reply_to_post"
|
||||
post_id: str
|
||||
user_post_id: str
|
||||
type: Literal["text_reply_to_message"] = "text_reply_to_message"
|
||||
message_id: str
|
||||
user_message_id: str
|
||||
text: str
|
||||
|
||||
|
||||
class PostRating(Interaction):
|
||||
"""A user has rated a post."""
|
||||
class MessageRating(Interaction):
|
||||
"""A user has rated a message."""
|
||||
|
||||
type: Literal["post_rating"] = "post_rating"
|
||||
post_id: str
|
||||
type: Literal["message_rating"] = "message_rating"
|
||||
message_id: str
|
||||
rating: int
|
||||
|
||||
|
||||
class PostRanking(Interaction):
|
||||
"""A user has given a ranking for a post."""
|
||||
class MessageRanking(Interaction):
|
||||
"""A user has given a ranking for a message."""
|
||||
|
||||
type: Literal["post_ranking"] = "post_ranking"
|
||||
post_id: str
|
||||
type: Literal["message_ranking"] = "message_ranking"
|
||||
message_id: str
|
||||
ranking: list[int]
|
||||
|
||||
|
||||
AnyInteraction = Union[
|
||||
TextReplyToPost,
|
||||
PostRating,
|
||||
PostRanking,
|
||||
TextReplyToMessage,
|
||||
MessageRating,
|
||||
MessageRanking,
|
||||
]
|
||||
|
||||
|
||||
@@ -245,12 +245,12 @@ class TextLabels(BaseModel):
|
||||
|
||||
text: str
|
||||
labels: dict[TextLabel, float]
|
||||
post_id: str | None = None
|
||||
message_id: str | None = None
|
||||
|
||||
@property
|
||||
def has_post_id(self) -> bool:
|
||||
"""Whether this TextLabels has a post_id."""
|
||||
return bool(self.post_id)
|
||||
def has_message_id(self) -> bool:
|
||||
"""Whether this TextLabels has a message_id."""
|
||||
return bool(self.message_id)
|
||||
|
||||
# check that each label value is between 0 and 1
|
||||
@pydantic.validator("labels")
|
||||
|
||||
+40
-40
@@ -13,7 +13,7 @@ app = typer.Typer()
|
||||
USER = {"id": "1234", "display_name": "John Doe", "auth_method": "local"}
|
||||
|
||||
|
||||
def _random_post_id():
|
||||
def _random_message_id():
|
||||
return str(random.randint(1000, 9999))
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ def _render_message(message: dict) -> str:
|
||||
"""Render a message to the user."""
|
||||
if message["is_assistant"]:
|
||||
return f"Assistant: {message['text']}"
|
||||
return f"User: {message['text']}"
|
||||
return f"Prompter: {message['text']}"
|
||||
|
||||
|
||||
@app.command()
|
||||
@@ -43,20 +43,20 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
typer.echo(task["story"])
|
||||
|
||||
# acknowledge task
|
||||
post_id = _random_post_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"post_id": post_id})
|
||||
message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
|
||||
summary = typer.prompt("Enter your summary")
|
||||
|
||||
user_post_id = _random_post_id()
|
||||
user_message_id = _random_message_id()
|
||||
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_post",
|
||||
"post_id": post_id,
|
||||
"user_post_id": user_post_id,
|
||||
"type": "text_reply_to_message",
|
||||
"message_id": message_id,
|
||||
"user_message_id": user_message_id,
|
||||
"text": summary,
|
||||
"user": USER,
|
||||
},
|
||||
@@ -70,16 +70,16 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
typer.echo(f"Rating scale: {task['scale']['min']} - {task['scale']['max']}")
|
||||
|
||||
# acknowledge task
|
||||
post_id = _random_post_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"post_id": post_id})
|
||||
message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
|
||||
rating = typer.prompt("Enter your rating", type=int)
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "post_rating",
|
||||
"post_id": post_id,
|
||||
"type": "message_rating",
|
||||
"message_id": message_id,
|
||||
"rating": rating,
|
||||
"user": USER,
|
||||
},
|
||||
@@ -90,24 +90,24 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
if task["hint"]:
|
||||
typer.echo(f"Hint: {task['hint']}")
|
||||
# acknowledge task
|
||||
post_id = _random_post_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"post_id": post_id})
|
||||
message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
prompt = typer.prompt("Enter your prompt")
|
||||
user_post_id = _random_post_id()
|
||||
user_message_id = _random_message_id()
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_post",
|
||||
"post_id": post_id,
|
||||
"user_post_id": user_post_id,
|
||||
"type": "text_reply_to_message",
|
||||
"message_id": message_id,
|
||||
"user_message_id": user_message_id,
|
||||
"text": prompt,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "user_reply":
|
||||
case "prompter_reply":
|
||||
typer.echo("Please provide a reply to the assistant.")
|
||||
typer.echo("Here is the conversation so far:")
|
||||
for message in task["conversation"]["messages"]:
|
||||
@@ -115,17 +115,17 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
if task["hint"]:
|
||||
typer.echo(f"Hint: {task['hint']}")
|
||||
# acknowledge task
|
||||
post_id = _random_post_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"post_id": post_id})
|
||||
message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
reply = typer.prompt("Enter your reply")
|
||||
user_post_id = _random_post_id()
|
||||
user_message_id = _random_message_id()
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_post",
|
||||
"post_id": post_id,
|
||||
"user_post_id": user_post_id,
|
||||
"type": "text_reply_to_message",
|
||||
"message_id": message_id,
|
||||
"user_message_id": user_message_id,
|
||||
"text": reply,
|
||||
"user": USER,
|
||||
},
|
||||
@@ -138,17 +138,17 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
for message in task["conversation"]["messages"]:
|
||||
typer.echo(_render_message(message))
|
||||
# acknowledge task
|
||||
post_id = _random_post_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"post_id": post_id})
|
||||
message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
reply = typer.prompt("Enter your reply")
|
||||
user_post_id = _random_post_id()
|
||||
user_message_id = _random_message_id()
|
||||
# send interaction
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "text_reply_to_post",
|
||||
"post_id": post_id,
|
||||
"user_post_id": user_post_id,
|
||||
"type": "text_reply_to_message",
|
||||
"message_id": message_id,
|
||||
"user_message_id": user_message_id,
|
||||
"text": reply,
|
||||
"user": USER,
|
||||
},
|
||||
@@ -160,8 +160,8 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
for idx, prompt in enumerate(task["prompts"], start=1):
|
||||
typer.echo(f"{idx}: {prompt}")
|
||||
# acknowledge task
|
||||
post_id = _random_post_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"post_id": post_id})
|
||||
message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
|
||||
ranking_str = typer.prompt("Enter the prompt numbers in order of preference, separated by commas")
|
||||
ranking = [int(x) - 1 for x in ranking_str.split(",")]
|
||||
@@ -170,15 +170,15 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "post_ranking",
|
||||
"post_id": post_id,
|
||||
"type": "message_ranking",
|
||||
"message_id": message_id,
|
||||
"ranking": ranking,
|
||||
"user": USER,
|
||||
},
|
||||
)
|
||||
tasks.append(new_task)
|
||||
|
||||
case "rank_user_replies" | "rank_assistant_replies":
|
||||
case "rank_prompter_replies" | "rank_assistant_replies":
|
||||
typer.echo("Here is the conversation so far:")
|
||||
for message in task["conversation"]["messages"]:
|
||||
typer.echo(_render_message(message))
|
||||
@@ -186,8 +186,8 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
for idx, reply in enumerate(task["replies"], start=1):
|
||||
typer.echo(f"{idx}: {reply}")
|
||||
# acknowledge task
|
||||
post_id = _random_post_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"post_id": post_id})
|
||||
message_id = _random_message_id()
|
||||
_post(f"/api/v1/tasks/{task['id']}/ack", {"message_id": message_id})
|
||||
|
||||
ranking_str = typer.prompt("Enter the reply numbers in order of preference, separated by commas")
|
||||
ranking = [int(x) - 1 for x in ranking_str.split(",")]
|
||||
@@ -196,8 +196,8 @@ def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "DUMMY_KEY")
|
||||
new_task = _post(
|
||||
"/api/v1/tasks/interaction",
|
||||
{
|
||||
"type": "post_ranking",
|
||||
"post_id": post_id,
|
||||
"type": "message_ranking",
|
||||
"message_id": message_id,
|
||||
"ranking": ranking,
|
||||
"user": USER,
|
||||
},
|
||||
|
||||
@@ -37,3 +37,7 @@ next-env.d.ts
|
||||
|
||||
# Vim files
|
||||
*.swp
|
||||
|
||||
# cypress
|
||||
/cypress-visual-screenshots/diff
|
||||
/cypress-visual-screenshots/comparison
|
||||
|
||||
@@ -102,6 +102,21 @@ All static images, fonts, svgs, etc are stored in `public/`.
|
||||
|
||||
We're not really using CSS styles. `styles/` can be ignored.
|
||||
|
||||
## Testing the UI
|
||||
|
||||
Cypress is used for end-to-end (e2e) and component testing and is configured in `./cypress.config.ts`. The `./cypress` folder is used for supporting configuration files etc.
|
||||
|
||||
- Store e2e tests in the `./cypress/e2e` folder.
|
||||
- Store component tests adjacent to the component being tested. If you want to wriite a test for `./src/components/Layout.tsx` then store the test file at `./src/components/Layout.cy.tsx`.
|
||||
|
||||
A few npm scripts are available for convenience:
|
||||
|
||||
- `npm run cypress`: Useful for development, it opens Cypress and allows you to explore, run and debug tests. It assumes you have the NextJS site running at `localhost:3000`.
|
||||
- `npm run cypress:run`: Runs all tests. Useful for a quick sanity check before sending a PR or to run in CI pipelines.
|
||||
- `npm run cypress:image-baseline`: If you have tests failing because of visual changes that was expected, this command will update the baseline images stored in `./cypress-visual-screenshots/baseline` with those from the adjacent comparison folder. More can be found in the [docs of `uktrade/cypress-image-diff`](https://github.com/uktrade/cypress-image-diff/blob/main/docs/CLI.md#update-all-baseline-images-for-failing-tests).
|
||||
|
||||
Read more in the [./cypress README](cypress/).
|
||||
|
||||
## Best Practices
|
||||
|
||||
When writing code for the website, we have a few best practices:
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
import { defineConfig } from "cypress";
|
||||
import getCompareSnapshotsPlugin from "cypress-image-diff-js/dist/plugin";
|
||||
|
||||
export default defineConfig({
|
||||
e2e: {
|
||||
baseUrl: "http://localhost:3000",
|
||||
setupNodeEvents(on, config) {
|
||||
// implement node event listeners here
|
||||
getCompareSnapshotsPlugin(on, config);
|
||||
},
|
||||
},
|
||||
|
||||
component: {
|
||||
devServer: {
|
||||
framework: "next",
|
||||
bundler: "webpack",
|
||||
viewportWidth: 500,
|
||||
viewportHeight: 500,
|
||||
},
|
||||
setupNodeEvents(on, config) {
|
||||
// implement node event listeners here
|
||||
getCompareSnapshotsPlugin(on, config);
|
||||
},
|
||||
},
|
||||
});
|
||||
@@ -0,0 +1,62 @@
|
||||
# Component and e2e testing with Cypress
|
||||
|
||||
[Cypress](https://www.cypress.io/) is used for both component- and end-to-end testing. Below there's a few examples for the context of this site. To learn more, the [Cypress documentation](https://docs.cypress.io/guides/getting-started/opening-the-app) has it all.
|
||||
|
||||
Don't get scared by the commercial offerings they offer. Their core is open source, the cloud offering is not necesarry at all and can be replaced by CI tooling and [community efforts](https://sorry-cypress.dev/).
|
||||
|
||||
# Component testing
|
||||
|
||||
To write a new component test, you either create a new `.tsx` adjacent to the component you want to test or you can use the guide presented yo you when running `npm run cypress` which allows you to easily create the skeleton test for an existing component.
|
||||
|
||||
If you have a `Button.tsx` component, create a file next to it called `Button.cy.tsx` which could look like this:
|
||||
|
||||
```typescript
|
||||
import React from "react";
|
||||
import { Button } from "./Button";
|
||||
|
||||
describe("<Button />", () => {
|
||||
it("renders", () => {
|
||||
// see: https://on.cypress.io/mounting-react
|
||||
cy.mount(<Button className="border-gray-800 m-5">Test button</Button>);
|
||||
cy.get("button").compareSnapshot("button-element");
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
## What's happening here?
|
||||
|
||||
First we use `cy.mount` to mount our component under test. Notive how we specify `className` and inner text - this is where we arrange our component with fake data that we could assert on later.
|
||||
|
||||
In the example above, we also use `cy.get` to select the rendered `button` element. Cypress has multiple ways to [select elements](https://docs.cypress.io/guides/references/best-practices), `get` is just one of them (and often not recommended).
|
||||
|
||||
At last, we use `captureSnapshot` which is a plugin that snaps a photo of the `button` element and compares it to a baseline located in the `./cypress-visual-screenshots/baseline/` folder. If there's too many unidentical pixels between the two, it will fail the test.
|
||||
|
||||
# End-to-end (e2e) testing
|
||||
|
||||
e2e tests are stored in the `./cypress/e2e` folder and should be named `{page}.cy.ts` and located in a relative folder structure that mirrors the page under test.
|
||||
|
||||
When running `npm run cypress` and selecting e2e testing, we assume you have the NextJS site running at `localhost:3000`.
|
||||
|
||||
An example test from this time of writing, could look as follows:
|
||||
|
||||
```typescript
|
||||
describe("signin flow", () => {
|
||||
it("redirects to a confirmation page on submit of valid email address", () => {
|
||||
cy.visit("/auth/signin");
|
||||
cy.get(".chakra-input").type(`test@example.com{enter}`);
|
||||
cy.url().should("contain", "/auth/verify");
|
||||
});
|
||||
});
|
||||
|
||||
export {};
|
||||
```
|
||||
|
||||
## What's happening here?
|
||||
|
||||
First we use [`cy.visit`](https://docs.cypress.io/api/commands/visit) to point the browser at the desired page. It appends relative paths to the configured `baseUrl` (found in `./cypress.config.ts`).
|
||||
|
||||
Cypress will [automatically await](https://docs.cypress.io/guides/core-concepts/introduction-to-cypress#Timeouts) almost anything you do, but fail if the default timeout is reached.
|
||||
|
||||
Then we get the email input field and type our email address. Notice the `{enter}` keyword, this will cause Cypress to hit the return key which we expect to submit the form.
|
||||
|
||||
We then assert that the URL should contain `/auth/verify`. Again the timeout will make sure we are not waiting forever, and the test will fail if we do not manage to get there in a reasonable time.
|
||||
@@ -0,0 +1,10 @@
|
||||
describe("signin flow", () => {
|
||||
it("redirects to a confirmation page on submit of valid email address", () => {
|
||||
cy.visit("/auth/signin");
|
||||
cy.get(".chakra-input").type(`test@example.com`);
|
||||
cy.get(".chakra-stack > .chakra-button").click();
|
||||
cy.url().should("contain", "/auth/verify");
|
||||
});
|
||||
});
|
||||
|
||||
export {};
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"name": "Using fixtures to represent data",
|
||||
"email": "hello@cypress.io",
|
||||
"body": "Fixtures are a great way to mock data for responses to routes"
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
/// <reference types="cypress" />
|
||||
// ***********************************************
|
||||
// This example commands.ts shows you how to
|
||||
// create various custom commands and overwrite
|
||||
// existing commands.
|
||||
//
|
||||
// For more comprehensive examples of custom
|
||||
// commands please read more here:
|
||||
// https://on.cypress.io/custom-commands
|
||||
// ***********************************************
|
||||
//
|
||||
//
|
||||
// -- This is a parent command --
|
||||
// Cypress.Commands.add('login', (email, password) => { ... })
|
||||
//
|
||||
//
|
||||
// -- This is a child command --
|
||||
// Cypress.Commands.add('drag', { prevSubject: 'element'}, (subject, options) => { ... })
|
||||
//
|
||||
//
|
||||
// -- This is a dual command --
|
||||
// Cypress.Commands.add('dismiss', { prevSubject: 'optional'}, (subject, options) => { ... })
|
||||
//
|
||||
//
|
||||
// -- This will overwrite an existing command --
|
||||
// Cypress.Commands.overwrite('visit', (originalFn, url, options) => { ... })
|
||||
//
|
||||
// declare global {
|
||||
// namespace Cypress {
|
||||
// interface Chainable {
|
||||
// login(email: string, password: string): Chainable<void>
|
||||
// drag(subject: string, options?: Partial<TypeOptions>): Chainable<Element>
|
||||
// dismiss(subject: string, options?: Partial<TypeOptions>): Chainable<Element>
|
||||
// visit(originalFn: CommandOriginalFn, url: string, options: Partial<VisitOptions>): Chainable<Element>
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
export {};
|
||||
@@ -0,0 +1,14 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=edge" />
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1.0" />
|
||||
<title>Components App</title>
|
||||
<!-- Used by Next.js to inject CSS. -->
|
||||
<div id="__next_css__DO_NOT_USE__"></div>
|
||||
</head>
|
||||
<body>
|
||||
<div data-cy-root></div>
|
||||
</body>
|
||||
</html>
|
||||
@@ -0,0 +1,45 @@
|
||||
// ***********************************************************
|
||||
// This example support/component.ts is processed and
|
||||
// loaded automatically before your test files.
|
||||
//
|
||||
// This is a great place to put global configuration and
|
||||
// behavior that modifies Cypress.
|
||||
//
|
||||
// You can change the location of this file or turn off
|
||||
// automatically serving support files with the
|
||||
// 'supportFile' configuration option.
|
||||
//
|
||||
// You can read more here:
|
||||
// https://on.cypress.io/configuration
|
||||
// ***********************************************************
|
||||
|
||||
// Import commands.js using ES2015 syntax:
|
||||
import "./commands";
|
||||
import compareSnapshotCommand from "cypress-image-diff-js/dist/command";
|
||||
import "../../src/styles/globals.css";
|
||||
|
||||
// Alternatively you can use CommonJS syntax:
|
||||
// require('./commands')
|
||||
|
||||
import { mount } from "cypress/react18";
|
||||
|
||||
// Augment the Cypress namespace to include type definitions for
|
||||
// your custom command.
|
||||
// Alternatively, can be defined in cypress/support/component.d.ts
|
||||
// with a <reference path="./component" /> at the top of your spec.
|
||||
declare global {
|
||||
namespace Cypress {
|
||||
interface Chainable {
|
||||
mount: typeof mount;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Cypress.Commands.add("mount", mount);
|
||||
|
||||
// Example use:
|
||||
// cy.mount(<MyComponent />)
|
||||
|
||||
compareSnapshotCommand();
|
||||
|
||||
export {};
|
||||
@@ -0,0 +1,24 @@
|
||||
// ***********************************************************
|
||||
// This example support/e2e.ts is processed and
|
||||
// loaded automatically before your test files.
|
||||
//
|
||||
// This is a great place to put global configuration and
|
||||
// behavior that modifies Cypress.
|
||||
//
|
||||
// You can change the location of this file or turn off
|
||||
// automatically serving support files with the
|
||||
// 'supportFile' configuration option.
|
||||
//
|
||||
// You can read more here:
|
||||
// https://on.cypress.io/configuration
|
||||
// ***********************************************************
|
||||
|
||||
// Import commands.js using ES2015 syntax:
|
||||
import "./commands";
|
||||
import compareSnapshotCommand from "cypress-image-diff-js/dist/command";
|
||||
compareSnapshotCommand();
|
||||
|
||||
// Alternatively you can use CommonJS syntax:
|
||||
// require('./commands')
|
||||
|
||||
export {};
|
||||
Generated
+1490
-15487
File diff suppressed because it is too large
Load Diff
@@ -9,7 +9,10 @@
|
||||
"start": "next start",
|
||||
"lint": "next lint",
|
||||
"storybook": "start-storybook -p 6006",
|
||||
"build-storybook": "build-storybook"
|
||||
"build-storybook": "build-storybook",
|
||||
"cypress": "cypress open",
|
||||
"cypress:run": "cypress run",
|
||||
"cypress:image-baseline": "cypress-image-diff -u"
|
||||
},
|
||||
"dependencies": {
|
||||
"@chakra-ui/react": "^2.4.4",
|
||||
@@ -56,6 +59,8 @@
|
||||
"@types/node": "18.11.17",
|
||||
"@types/react": "18.0.26",
|
||||
"babel-loader": "^8.3.0",
|
||||
"cypress": "^12.2.0",
|
||||
"cypress-image-diff-js": "^1.23.0",
|
||||
"eslint-plugin-storybook": "^0.6.8",
|
||||
"@typescript-eslint/eslint-plugin": "^5.47.1",
|
||||
"prettier": "2.8.1",
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
import React from "react";
|
||||
import { Container } from "./Container";
|
||||
|
||||
describe("<Container />", () => {
|
||||
it("renders", () => {
|
||||
// see: https://on.cypress.io/mounting-react
|
||||
const className = "my-class";
|
||||
const text = "test_container";
|
||||
cy.mount(<Container className={className}>{text}</Container>);
|
||||
cy.get(`div.${className}`).should("have.class", className).should("be.visible").should("contain", text);
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user