mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Add terms of service acceptance date to user table (#1046)
* add tos_acceptance_date column to user * send 451 UNAVAILABLE_FOR_LEGAL_REASONS status * add create user REST endpoint * adapt text-frontend to ToS requirements * set DEBUG_IGNORE_TOS_ACCEPTANCE default to True (temporary change) * update down revision to f60958968ff8
This commit is contained in:
+34
@@ -0,0 +1,34 @@
|
||||
"""add tos_acceptance_date to user
|
||||
|
||||
Revision ID: 55361f323d12
|
||||
Revises: 7b8f0011e0b0
|
||||
Create Date: 2023-02-01 00:22:08.280251
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "55361f323d12"
|
||||
down_revision = "f60958968ff8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("user", sa.Column("tos_acceptance_date", sa.DateTime(timezone=True), nullable=True))
|
||||
op.drop_column("user_stats", "streak_days")
|
||||
op.drop_column("user_stats", "streak_last_day_date")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"user_stats", sa.Column("streak_last_day_date", postgresql.TIMESTAMP(), autoincrement=False, nullable=True)
|
||||
)
|
||||
op.add_column("user_stats", sa.Column("streak_days", sa.INTEGER(), autoincrement=False, nullable=True))
|
||||
op.drop_column("user", "tos_acceptance_date")
|
||||
# ### end Alembic commands ###
|
||||
@@ -147,6 +147,7 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
|
||||
ur = UserRepository(db=session, api_client=api_client)
|
||||
tr = TaskRepository(db=session, api_client=api_client, client_user=dummy_user, user_repository=ur)
|
||||
ur.update_user(tr.user_id, enabled=True, show_on_leaderboard=False, tos_acceptance=True)
|
||||
pr = PromptRepository(
|
||||
db=session, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
|
||||
)
|
||||
|
||||
@@ -59,6 +59,37 @@ def query_frontend_user(
|
||||
return user.to_protocol_frontend_user()
|
||||
|
||||
|
||||
@router.post("/", response_model=protocol.FrontEndUser)
|
||||
def create_frontend_user(
|
||||
*,
|
||||
create_user: protocol.CreateFrontendUserRequest,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
ur = UserRepository(db, api_client)
|
||||
user = ur.lookup_client_user(create_user, create_missing=True)
|
||||
|
||||
def changed(a, b) -> bool:
|
||||
return a is not None and a != b
|
||||
|
||||
# only call update_user if something changed
|
||||
if (
|
||||
changed(create_user.enabled, user.enabled)
|
||||
or changed(create_user.show_on_leaderboard, user.show_on_leaderboard)
|
||||
or changed(create_user.notes, user.notes)
|
||||
or (create_user.tos_acceptance and user.tos_acceptance_date is None)
|
||||
):
|
||||
user = ur.update_user(
|
||||
user.id,
|
||||
enabled=create_user.enabled,
|
||||
show_on_leaderboard=create_user.show_on_leaderboard,
|
||||
tos_acceptance=create_user.tos_acceptance,
|
||||
notes=create_user.notes,
|
||||
)
|
||||
|
||||
return user.to_protocol_frontend_user()
|
||||
|
||||
|
||||
@router.get("/{auth_method}/{username}/messages", response_model=list[protocol.Message])
|
||||
def query_frontend_user_messages(
|
||||
auth_method: str,
|
||||
|
||||
@@ -191,6 +191,7 @@ def update_user(
|
||||
enabled: Optional[bool] = None,
|
||||
notes: Optional[str] = None,
|
||||
show_on_leaderboard: Optional[bool] = None,
|
||||
tos_acceptance: Optional[bool] = None,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
@@ -198,7 +199,7 @@ def update_user(
|
||||
Update a user by global user ID. Only trusted clients can update users.
|
||||
"""
|
||||
ur = UserRepository(db, api_client)
|
||||
ur.update_user(user_id, enabled, notes, show_on_leaderboard)
|
||||
ur.update_user(user_id, enabled, notes, show_on_leaderboard, tos_acceptance)
|
||||
|
||||
|
||||
@router.delete("/{user_id}", status_code=HTTP_204_NO_CONTENT)
|
||||
|
||||
@@ -158,6 +158,9 @@ class Settings(BaseSettings):
|
||||
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
|
||||
DEBUG_SKIP_TOXICITY_CALCULATION: bool = False
|
||||
DEBUG_DATABASE_ECHO: bool = False
|
||||
DEBUG_IGNORE_TOS_ACCEPTANCE: bool = ( # ignore whether users accepted the ToS
|
||||
True # TODO: set False after ToS acceptance UI was added to web-frontend
|
||||
)
|
||||
|
||||
DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES: int = 120
|
||||
|
||||
|
||||
@@ -41,6 +41,9 @@ class User(SQLModel, table=True):
|
||||
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
|
||||
# terms of service acceptance date
|
||||
tos_acceptance_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
|
||||
|
||||
def to_protocol_frontend_user(self):
|
||||
return protocol.FrontEndUser(
|
||||
user_id=self.id,
|
||||
@@ -55,4 +58,5 @@ class User(SQLModel, table=True):
|
||||
streak_days=self.streak_days,
|
||||
streak_last_day_date=self.streak_last_day_date,
|
||||
last_activity_date=self.last_activity_date,
|
||||
tos_acceptance_date=self.tos_acceptance_date,
|
||||
)
|
||||
|
||||
@@ -35,7 +35,6 @@ from oasst_shared.utils import unaware_to_utc, utcnow
|
||||
from sqlalchemy.orm import Query
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class PromptRepository:
|
||||
@@ -77,7 +76,14 @@ class PromptRepository:
|
||||
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
|
||||
|
||||
if self.user.deleted or not self.user.enabled:
|
||||
raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED)
|
||||
raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED, HTTPStatus.SERVICE_UNAVAILABLE)
|
||||
|
||||
if self.user.tos_acceptance_date is None and not settings.DEBUG_IGNORE_TOS_ACCEPTANCE:
|
||||
raise OasstError(
|
||||
"User has not accepted terms of service.",
|
||||
OasstErrorCode.USER_HAS_NOT_ACCEPTED_TOS,
|
||||
HTTPStatus.UNAVAILABLE_FOR_LEGAL_REASONS,
|
||||
)
|
||||
|
||||
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
@@ -90,7 +96,7 @@ class PromptRepository:
|
||||
raise OasstError(
|
||||
f"Message with frontend_message_id {frontend_message_id} not found.",
|
||||
OasstErrorCode.MESSAGE_NOT_FOUND,
|
||||
HTTP_404_NOT_FOUND,
|
||||
HTTPStatus.NOT_FOUND,
|
||||
)
|
||||
return message
|
||||
|
||||
@@ -675,7 +681,7 @@ class PromptRepository:
|
||||
|
||||
message = self.db.query(Message).filter(Message.id == message_id).one_or_none()
|
||||
if fail_if_missing and not message:
|
||||
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTPStatus.NOT_FOUND)
|
||||
return message
|
||||
|
||||
def fetch_non_task_text_labels(self, message_id: UUID, user_id: UUID) -> Optional[TextLabels]:
|
||||
@@ -874,7 +880,7 @@ class PromptRepository:
|
||||
|
||||
if api_client_id != self.api_client.id:
|
||||
# Unprivileged api client asks for foreign messages
|
||||
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
|
||||
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTPStatus.FORBIDDEN)
|
||||
|
||||
qry = self.db.query(Message)
|
||||
if user_id:
|
||||
|
||||
@@ -73,7 +73,8 @@ class UserRepository:
|
||||
enabled: Optional[bool] = None,
|
||||
notes: Optional[str] = None,
|
||||
show_on_leaderboard: Optional[bool] = None,
|
||||
) -> None:
|
||||
tos_acceptance: Optional[bool] = None,
|
||||
) -> User:
|
||||
"""
|
||||
Update a user by global user ID to disable or set admin notes. Only trusted clients may update users.
|
||||
|
||||
@@ -94,8 +95,11 @@ class UserRepository:
|
||||
user.notes = notes
|
||||
if show_on_leaderboard is not None:
|
||||
user.show_on_leaderboard = show_on_leaderboard
|
||||
if tos_acceptance:
|
||||
user.tos_acceptance_date = utcnow()
|
||||
|
||||
self.db.add(user)
|
||||
return user
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def mark_user_deleted(self, id: UUID) -> None:
|
||||
@@ -143,8 +147,10 @@ class UserRepository:
|
||||
display_name=display_name,
|
||||
api_client_id=self.api_client.id,
|
||||
auth_method=auth_method,
|
||||
show_on_leaderboard=(auth_method != "system"), # don't show system users, e.g. import user
|
||||
)
|
||||
if auth_method == "system":
|
||||
user.show_on_leaderboard = False # don't show system users, e.g. import user
|
||||
user.tos_acceptance_date = utcnow()
|
||||
self.db.add(user)
|
||||
elif display_name and display_name != user.display_name:
|
||||
# we found the user but the display name changed
|
||||
@@ -156,6 +162,10 @@ class UserRepository:
|
||||
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> User | None:
|
||||
if not client_user:
|
||||
return None
|
||||
|
||||
if not (client_user.auth_method and client_user.id):
|
||||
raise OasstError("Auth method or username missing.", OasstErrorCode.AUTH_AND_USERNAME_REQUIRED)
|
||||
|
||||
num_retries = settings.DATABASE_MAX_TX_RETRY_COUNT
|
||||
for i in range(num_retries):
|
||||
try:
|
||||
|
||||
@@ -80,6 +80,7 @@ class OasstErrorCode(IntEnum):
|
||||
USER_NOT_SPECIFIED = 4000
|
||||
USER_DISABLED = 4001
|
||||
USER_NOT_FOUND = 4002
|
||||
USER_HAS_NOT_ACCEPTED_TOS = 4003
|
||||
|
||||
EMOJI_OP_UNSUPPORTED = 5000
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@ class FrontEndUser(User):
|
||||
streak_days: Optional[int] = None
|
||||
streak_last_day_date: Optional[datetime] = None
|
||||
last_activity_date: Optional[datetime] = None
|
||||
tos_acceptance_date: Optional[datetime] = None
|
||||
|
||||
|
||||
class PageResult(BaseModel):
|
||||
@@ -499,3 +500,10 @@ class MessageEmojiRequest(BaseModel):
|
||||
user: User
|
||||
op: EmojiOp = EmojiOp.togggle
|
||||
emoji: EmojiCode
|
||||
|
||||
|
||||
class CreateFrontendUserRequest(User):
|
||||
show_on_leaderboard: bool = True
|
||||
enabled: bool = True
|
||||
tos_acceptance: Optional[bool] = None
|
||||
notes: Optional[str] = None
|
||||
|
||||
@@ -28,6 +28,16 @@ def _render_message(message: dict) -> str:
|
||||
def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
|
||||
"""Simple REPL frontend."""
|
||||
|
||||
# make sure dummy user has accepted the terms of service
|
||||
create_user_request = dict(USER)
|
||||
create_user_request["tos_acceptance"] = True
|
||||
response = requests.post(
|
||||
f"{backend_url}/api/v1/frontend_users/", json=create_user_request, headers={"X-API-Key": api_key}
|
||||
)
|
||||
response.raise_for_status()
|
||||
user = response.json()
|
||||
typer.echo(f"user: {user}")
|
||||
|
||||
def _post(path: str, json: dict) -> dict:
|
||||
response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key})
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -29,6 +29,16 @@ def _render_message(message: dict) -> str:
|
||||
def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
|
||||
"""automates tasks"""
|
||||
|
||||
# make sure dummy user has accepted the terms of service
|
||||
create_user_request = dict(USER)
|
||||
create_user_request["tos_acceptance"] = True
|
||||
response = requests.post(
|
||||
f"{backend_url}/api/v1/frontend_users/", json=create_user_request, headers={"X-API-Key": api_key}
|
||||
)
|
||||
response.raise_for_status()
|
||||
user = response.json()
|
||||
typer.echo(f"user: {user}")
|
||||
|
||||
def _post(path: str, json: dict) -> dict:
|
||||
response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key})
|
||||
response.raise_for_status()
|
||||
|
||||
Reference in New Issue
Block a user