From bbf038677cd9fe247db15d6758dd687ce9db1abb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Wed, 1 Feb 2023 23:53:21 +0100 Subject: [PATCH] Add terms of service acceptance date to user table (#1046) * add tos_acceptance_date column to user * send 451 UNAVAILABLE_FOR_LEGAL_REASONS status * add create user REST endpoint * adapt text-frontend to ToS requirements * set DEBUG_IGNORE_TOS_ACCEPTANCE default to True (temporary change) * update down revision to f60958968ff8 --- ...f323d12_add_tos_acceptance_date_to_user.py | 34 +++++++++++++++++++ backend/main.py | 1 + .../oasst_backend/api/v1/frontend_users.py | 31 +++++++++++++++++ backend/oasst_backend/api/v1/users.py | 3 +- backend/oasst_backend/config.py | 3 ++ backend/oasst_backend/models/user.py | 4 +++ backend/oasst_backend/prompt_repository.py | 16 ++++++--- backend/oasst_backend/user_repository.py | 14 ++++++-- .../exceptions/oasst_api_error.py | 1 + oasst-shared/oasst_shared/schemas/protocol.py | 8 +++++ text-frontend/__main__.py | 10 ++++++ text-frontend/auto_main.py | 10 ++++++ 12 files changed, 127 insertions(+), 8 deletions(-) create mode 100644 backend/alembic/versions/2023_02_01_0022-55361f323d12_add_tos_acceptance_date_to_user.py diff --git a/backend/alembic/versions/2023_02_01_0022-55361f323d12_add_tos_acceptance_date_to_user.py b/backend/alembic/versions/2023_02_01_0022-55361f323d12_add_tos_acceptance_date_to_user.py new file mode 100644 index 00000000..bca17b4f --- /dev/null +++ b/backend/alembic/versions/2023_02_01_0022-55361f323d12_add_tos_acceptance_date_to_user.py @@ -0,0 +1,34 @@ +"""add tos_acceptance_date to user + +Revision ID: 55361f323d12 +Revises: 7b8f0011e0b0 +Create Date: 2023-02-01 00:22:08.280251 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "55361f323d12" +down_revision = "f60958968ff8" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("user", sa.Column("tos_acceptance_date", sa.DateTime(timezone=True), nullable=True)) + op.drop_column("user_stats", "streak_days") + op.drop_column("user_stats", "streak_last_day_date") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "user_stats", sa.Column("streak_last_day_date", postgresql.TIMESTAMP(), autoincrement=False, nullable=True) + ) + op.add_column("user_stats", sa.Column("streak_days", sa.INTEGER(), autoincrement=False, nullable=True)) + op.drop_column("user", "tos_acceptance_date") + # ### end Alembic commands ### diff --git a/backend/main.py b/backend/main.py index 07d0b45b..8e30b78e 100644 --- a/backend/main.py +++ b/backend/main.py @@ -147,6 +147,7 @@ if settings.DEBUG_USE_SEED_DATA: ur = UserRepository(db=session, api_client=api_client) tr = TaskRepository(db=session, api_client=api_client, client_user=dummy_user, user_repository=ur) + ur.update_user(tr.user_id, enabled=True, show_on_leaderboard=False, tos_acceptance=True) pr = PromptRepository( db=session, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr ) diff --git a/backend/oasst_backend/api/v1/frontend_users.py b/backend/oasst_backend/api/v1/frontend_users.py index a4ca6380..114a3a9c 100644 --- a/backend/oasst_backend/api/v1/frontend_users.py +++ b/backend/oasst_backend/api/v1/frontend_users.py @@ -59,6 +59,37 @@ def query_frontend_user( return user.to_protocol_frontend_user() +@router.post("/", response_model=protocol.FrontEndUser) +def create_frontend_user( + *, + create_user: protocol.CreateFrontendUserRequest, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), +): + ur = UserRepository(db, api_client) + user = ur.lookup_client_user(create_user, create_missing=True) + + def changed(a, b) -> bool: + return a is not None and a != b + + # only call update_user if something changed + if ( + changed(create_user.enabled, user.enabled) + or changed(create_user.show_on_leaderboard, user.show_on_leaderboard) + or changed(create_user.notes, user.notes) + or (create_user.tos_acceptance and user.tos_acceptance_date is None) + ): + user = ur.update_user( + user.id, + enabled=create_user.enabled, + show_on_leaderboard=create_user.show_on_leaderboard, + tos_acceptance=create_user.tos_acceptance, + notes=create_user.notes, + ) + + return user.to_protocol_frontend_user() + + @router.get("/{auth_method}/{username}/messages", response_model=list[protocol.Message]) def query_frontend_user_messages( auth_method: str, diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py index 2ced40c1..b3604c3f 100644 --- a/backend/oasst_backend/api/v1/users.py +++ b/backend/oasst_backend/api/v1/users.py @@ -191,6 +191,7 @@ def update_user( enabled: Optional[bool] = None, notes: Optional[str] = None, show_on_leaderboard: Optional[bool] = None, + tos_acceptance: Optional[bool] = None, db: Session = Depends(deps.get_db), api_client: ApiClient = Depends(deps.get_trusted_api_client), ): @@ -198,7 +199,7 @@ def update_user( Update a user by global user ID. Only trusted clients can update users. """ ur = UserRepository(db, api_client) - ur.update_user(user_id, enabled, notes, show_on_leaderboard) + ur.update_user(user_id, enabled, notes, show_on_leaderboard, tos_acceptance) @router.delete("/{user_id}", status_code=HTTP_204_NO_CONTENT) diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 17935021..f40f5637 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -158,6 +158,9 @@ class Settings(BaseSettings): DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False DEBUG_SKIP_TOXICITY_CALCULATION: bool = False DEBUG_DATABASE_ECHO: bool = False + DEBUG_IGNORE_TOS_ACCEPTANCE: bool = ( # ignore whether users accepted the ToS + True # TODO: set False after ToS acceptance UI was added to web-frontend + ) DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES: int = 120 diff --git a/backend/oasst_backend/models/user.py b/backend/oasst_backend/models/user.py index 3d3bd6a9..6c73089c 100644 --- a/backend/oasst_backend/models/user.py +++ b/backend/oasst_backend/models/user.py @@ -41,6 +41,9 @@ class User(SQLModel, table=True): sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True, server_default=sa.func.current_timestamp()) ) + # terms of service acceptance date + tos_acceptance_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True)) + def to_protocol_frontend_user(self): return protocol.FrontEndUser( user_id=self.id, @@ -55,4 +58,5 @@ class User(SQLModel, table=True): streak_days=self.streak_days, streak_last_day_date=self.streak_last_day_date, last_activity_date=self.last_activity_date, + tos_acceptance_date=self.tos_acceptance_date, ) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index e889e73b..5d46aa2d 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -35,7 +35,6 @@ from oasst_shared.utils import unaware_to_utc, utcnow from sqlalchemy.orm import Query from sqlalchemy.orm.attributes import flag_modified from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update -from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND class PromptRepository: @@ -77,7 +76,14 @@ class PromptRepository: raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED) if self.user.deleted or not self.user.enabled: - raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED) + raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED, HTTPStatus.SERVICE_UNAVAILABLE) + + if self.user.tos_acceptance_date is None and not settings.DEBUG_IGNORE_TOS_ACCEPTANCE: + raise OasstError( + "User has not accepted terms of service.", + OasstErrorCode.USER_HAS_NOT_ACCEPTED_TOS, + HTTPStatus.UNAVAILABLE_FOR_LEGAL_REASONS, + ) def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message: validate_frontend_message_id(frontend_message_id) @@ -90,7 +96,7 @@ class PromptRepository: raise OasstError( f"Message with frontend_message_id {frontend_message_id} not found.", OasstErrorCode.MESSAGE_NOT_FOUND, - HTTP_404_NOT_FOUND, + HTTPStatus.NOT_FOUND, ) return message @@ -675,7 +681,7 @@ class PromptRepository: message = self.db.query(Message).filter(Message.id == message_id).one_or_none() if fail_if_missing and not message: - raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND) + raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTPStatus.NOT_FOUND) return message def fetch_non_task_text_labels(self, message_id: UUID, user_id: UUID) -> Optional[TextLabels]: @@ -874,7 +880,7 @@ class PromptRepository: if api_client_id != self.api_client.id: # Unprivileged api client asks for foreign messages - raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN) + raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTPStatus.FORBIDDEN) qry = self.db.query(Message) if user_id: diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 984964b6..ba6d1a10 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -73,7 +73,8 @@ class UserRepository: enabled: Optional[bool] = None, notes: Optional[str] = None, show_on_leaderboard: Optional[bool] = None, - ) -> None: + tos_acceptance: Optional[bool] = None, + ) -> User: """ Update a user by global user ID to disable or set admin notes. Only trusted clients may update users. @@ -94,8 +95,11 @@ class UserRepository: user.notes = notes if show_on_leaderboard is not None: user.show_on_leaderboard = show_on_leaderboard + if tos_acceptance: + user.tos_acceptance_date = utcnow() self.db.add(user) + return user @managed_tx_method(CommitMode.COMMIT) def mark_user_deleted(self, id: UUID) -> None: @@ -143,8 +147,10 @@ class UserRepository: display_name=display_name, api_client_id=self.api_client.id, auth_method=auth_method, - show_on_leaderboard=(auth_method != "system"), # don't show system users, e.g. import user ) + if auth_method == "system": + user.show_on_leaderboard = False # don't show system users, e.g. import user + user.tos_acceptance_date = utcnow() self.db.add(user) elif display_name and display_name != user.display_name: # we found the user but the display name changed @@ -156,6 +162,10 @@ class UserRepository: def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> User | None: if not client_user: return None + + if not (client_user.auth_method and client_user.id): + raise OasstError("Auth method or username missing.", OasstErrorCode.AUTH_AND_USERNAME_REQUIRED) + num_retries = settings.DATABASE_MAX_TX_RETRY_COUNT for i in range(num_retries): try: diff --git a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py index 2c3650a6..a8682a32 100644 --- a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py +++ b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py @@ -80,6 +80,7 @@ class OasstErrorCode(IntEnum): USER_NOT_SPECIFIED = 4000 USER_DISABLED = 4001 USER_NOT_FOUND = 4002 + USER_HAS_NOT_ACCEPTED_TOS = 4003 EMOJI_OP_UNSUPPORTED = 5000 diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 8929251c..1139286c 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -39,6 +39,7 @@ class FrontEndUser(User): streak_days: Optional[int] = None streak_last_day_date: Optional[datetime] = None last_activity_date: Optional[datetime] = None + tos_acceptance_date: Optional[datetime] = None class PageResult(BaseModel): @@ -499,3 +500,10 @@ class MessageEmojiRequest(BaseModel): user: User op: EmojiOp = EmojiOp.togggle emoji: EmojiCode + + +class CreateFrontendUserRequest(User): + show_on_leaderboard: bool = True + enabled: bool = True + tos_acceptance: Optional[bool] = None + notes: Optional[str] = None diff --git a/text-frontend/__main__.py b/text-frontend/__main__.py index 18c1f124..b3f4d925 100644 --- a/text-frontend/__main__.py +++ b/text-frontend/__main__.py @@ -28,6 +28,16 @@ def _render_message(message: dict) -> str: def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"): """Simple REPL frontend.""" + # make sure dummy user has accepted the terms of service + create_user_request = dict(USER) + create_user_request["tos_acceptance"] = True + response = requests.post( + f"{backend_url}/api/v1/frontend_users/", json=create_user_request, headers={"X-API-Key": api_key} + ) + response.raise_for_status() + user = response.json() + typer.echo(f"user: {user}") + def _post(path: str, json: dict) -> dict: response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key}) response.raise_for_status() diff --git a/text-frontend/auto_main.py b/text-frontend/auto_main.py index 485ee1cb..2775d98c 100644 --- a/text-frontend/auto_main.py +++ b/text-frontend/auto_main.py @@ -29,6 +29,16 @@ def _render_message(message: dict) -> str: def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"): """automates tasks""" + # make sure dummy user has accepted the terms of service + create_user_request = dict(USER) + create_user_request["tos_acceptance"] = True + response = requests.post( + f"{backend_url}/api/v1/frontend_users/", json=create_user_request, headers={"X-API-Key": api_key} + ) + response.raise_for_status() + user = response.json() + typer.echo(f"user: {user}") + def _post(path: str, json: dict) -> dict: response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key}) response.raise_for_status()