Add terms of service acceptance date to user table (#1046)

* add tos_acceptance_date column to user

* send 451 UNAVAILABLE_FOR_LEGAL_REASONS status

* add create user REST endpoint

* adapt text-frontend to ToS requirements

* set DEBUG_IGNORE_TOS_ACCEPTANCE default to True (temporary change)

* update down revision to f60958968ff8
This commit is contained in:
Andreas Köpf
2023-02-01 23:53:21 +01:00
committed by GitHub
parent e0df9f0b7c
commit bbf038677c
12 changed files with 127 additions and 8 deletions
@@ -0,0 +1,34 @@
"""add tos_acceptance_date to user
Revision ID: 55361f323d12
Revises: 7b8f0011e0b0
Create Date: 2023-02-01 00:22:08.280251
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "55361f323d12"
down_revision = "f60958968ff8"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("user", sa.Column("tos_acceptance_date", sa.DateTime(timezone=True), nullable=True))
op.drop_column("user_stats", "streak_days")
op.drop_column("user_stats", "streak_last_day_date")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"user_stats", sa.Column("streak_last_day_date", postgresql.TIMESTAMP(), autoincrement=False, nullable=True)
)
op.add_column("user_stats", sa.Column("streak_days", sa.INTEGER(), autoincrement=False, nullable=True))
op.drop_column("user", "tos_acceptance_date")
# ### end Alembic commands ###
+1
View File
@@ -147,6 +147,7 @@ if settings.DEBUG_USE_SEED_DATA:
ur = UserRepository(db=session, api_client=api_client)
tr = TaskRepository(db=session, api_client=api_client, client_user=dummy_user, user_repository=ur)
ur.update_user(tr.user_id, enabled=True, show_on_leaderboard=False, tos_acceptance=True)
pr = PromptRepository(
db=session, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
)
@@ -59,6 +59,37 @@ def query_frontend_user(
return user.to_protocol_frontend_user()
@router.post("/", response_model=protocol.FrontEndUser)
def create_frontend_user(
*,
create_user: protocol.CreateFrontendUserRequest,
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
ur = UserRepository(db, api_client)
user = ur.lookup_client_user(create_user, create_missing=True)
def changed(a, b) -> bool:
return a is not None and a != b
# only call update_user if something changed
if (
changed(create_user.enabled, user.enabled)
or changed(create_user.show_on_leaderboard, user.show_on_leaderboard)
or changed(create_user.notes, user.notes)
or (create_user.tos_acceptance and user.tos_acceptance_date is None)
):
user = ur.update_user(
user.id,
enabled=create_user.enabled,
show_on_leaderboard=create_user.show_on_leaderboard,
tos_acceptance=create_user.tos_acceptance,
notes=create_user.notes,
)
return user.to_protocol_frontend_user()
@router.get("/{auth_method}/{username}/messages", response_model=list[protocol.Message])
def query_frontend_user_messages(
auth_method: str,
+2 -1
View File
@@ -191,6 +191,7 @@ def update_user(
enabled: Optional[bool] = None,
notes: Optional[str] = None,
show_on_leaderboard: Optional[bool] = None,
tos_acceptance: Optional[bool] = None,
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
@@ -198,7 +199,7 @@ def update_user(
Update a user by global user ID. Only trusted clients can update users.
"""
ur = UserRepository(db, api_client)
ur.update_user(user_id, enabled, notes, show_on_leaderboard)
ur.update_user(user_id, enabled, notes, show_on_leaderboard, tos_acceptance)
@router.delete("/{user_id}", status_code=HTTP_204_NO_CONTENT)
+3
View File
@@ -158,6 +158,9 @@ class Settings(BaseSettings):
DEBUG_SKIP_EMBEDDING_COMPUTATION: bool = False
DEBUG_SKIP_TOXICITY_CALCULATION: bool = False
DEBUG_DATABASE_ECHO: bool = False
DEBUG_IGNORE_TOS_ACCEPTANCE: bool = ( # ignore whether users accepted the ToS
True # TODO: set False after ToS acceptance UI was added to web-frontend
)
DUPLICATE_MESSAGE_FILTER_WINDOW_MINUTES: int = 120
+4
View File
@@ -41,6 +41,9 @@ class User(SQLModel, table=True):
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True, server_default=sa.func.current_timestamp())
)
# terms of service acceptance date
tos_acceptance_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(timezone=True), nullable=True))
def to_protocol_frontend_user(self):
return protocol.FrontEndUser(
user_id=self.id,
@@ -55,4 +58,5 @@ class User(SQLModel, table=True):
streak_days=self.streak_days,
streak_last_day_date=self.streak_last_day_date,
last_activity_date=self.last_activity_date,
tos_acceptance_date=self.tos_acceptance_date,
)
+11 -5
View File
@@ -35,7 +35,6 @@ from oasst_shared.utils import unaware_to_utc, utcnow
from sqlalchemy.orm import Query
from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import JSON, Session, and_, func, literal_column, not_, or_, text, update
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
class PromptRepository:
@@ -77,7 +76,14 @@ class PromptRepository:
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
if self.user.deleted or not self.user.enabled:
raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED)
raise OasstError("User account disabled", OasstErrorCode.USER_DISABLED, HTTPStatus.SERVICE_UNAVAILABLE)
if self.user.tos_acceptance_date is None and not settings.DEBUG_IGNORE_TOS_ACCEPTANCE:
raise OasstError(
"User has not accepted terms of service.",
OasstErrorCode.USER_HAS_NOT_ACCEPTED_TOS,
HTTPStatus.UNAVAILABLE_FOR_LEGAL_REASONS,
)
def fetch_message_by_frontend_message_id(self, frontend_message_id: str, fail_if_missing: bool = True) -> Message:
validate_frontend_message_id(frontend_message_id)
@@ -90,7 +96,7 @@ class PromptRepository:
raise OasstError(
f"Message with frontend_message_id {frontend_message_id} not found.",
OasstErrorCode.MESSAGE_NOT_FOUND,
HTTP_404_NOT_FOUND,
HTTPStatus.NOT_FOUND,
)
return message
@@ -675,7 +681,7 @@ class PromptRepository:
message = self.db.query(Message).filter(Message.id == message_id).one_or_none()
if fail_if_missing and not message:
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTPStatus.NOT_FOUND)
return message
def fetch_non_task_text_labels(self, message_id: UUID, user_id: UUID) -> Optional[TextLabels]:
@@ -874,7 +880,7 @@ class PromptRepository:
if api_client_id != self.api_client.id:
# Unprivileged api client asks for foreign messages
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTPStatus.FORBIDDEN)
qry = self.db.query(Message)
if user_id:
+12 -2
View File
@@ -73,7 +73,8 @@ class UserRepository:
enabled: Optional[bool] = None,
notes: Optional[str] = None,
show_on_leaderboard: Optional[bool] = None,
) -> None:
tos_acceptance: Optional[bool] = None,
) -> User:
"""
Update a user by global user ID to disable or set admin notes. Only trusted clients may update users.
@@ -94,8 +95,11 @@ class UserRepository:
user.notes = notes
if show_on_leaderboard is not None:
user.show_on_leaderboard = show_on_leaderboard
if tos_acceptance:
user.tos_acceptance_date = utcnow()
self.db.add(user)
return user
@managed_tx_method(CommitMode.COMMIT)
def mark_user_deleted(self, id: UUID) -> None:
@@ -143,8 +147,10 @@ class UserRepository:
display_name=display_name,
api_client_id=self.api_client.id,
auth_method=auth_method,
show_on_leaderboard=(auth_method != "system"), # don't show system users, e.g. import user
)
if auth_method == "system":
user.show_on_leaderboard = False # don't show system users, e.g. import user
user.tos_acceptance_date = utcnow()
self.db.add(user)
elif display_name and display_name != user.display_name:
# we found the user but the display name changed
@@ -156,6 +162,10 @@ class UserRepository:
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> User | None:
if not client_user:
return None
if not (client_user.auth_method and client_user.id):
raise OasstError("Auth method or username missing.", OasstErrorCode.AUTH_AND_USERNAME_REQUIRED)
num_retries = settings.DATABASE_MAX_TX_RETRY_COUNT
for i in range(num_retries):
try:
@@ -80,6 +80,7 @@ class OasstErrorCode(IntEnum):
USER_NOT_SPECIFIED = 4000
USER_DISABLED = 4001
USER_NOT_FOUND = 4002
USER_HAS_NOT_ACCEPTED_TOS = 4003
EMOJI_OP_UNSUPPORTED = 5000
@@ -39,6 +39,7 @@ class FrontEndUser(User):
streak_days: Optional[int] = None
streak_last_day_date: Optional[datetime] = None
last_activity_date: Optional[datetime] = None
tos_acceptance_date: Optional[datetime] = None
class PageResult(BaseModel):
@@ -499,3 +500,10 @@ class MessageEmojiRequest(BaseModel):
user: User
op: EmojiOp = EmojiOp.togggle
emoji: EmojiCode
class CreateFrontendUserRequest(User):
show_on_leaderboard: bool = True
enabled: bool = True
tos_acceptance: Optional[bool] = None
notes: Optional[str] = None
+10
View File
@@ -28,6 +28,16 @@ def _render_message(message: dict) -> str:
def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
"""Simple REPL frontend."""
# make sure dummy user has accepted the terms of service
create_user_request = dict(USER)
create_user_request["tos_acceptance"] = True
response = requests.post(
f"{backend_url}/api/v1/frontend_users/", json=create_user_request, headers={"X-API-Key": api_key}
)
response.raise_for_status()
user = response.json()
typer.echo(f"user: {user}")
def _post(path: str, json: dict) -> dict:
response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key})
response.raise_for_status()
+10
View File
@@ -29,6 +29,16 @@ def _render_message(message: dict) -> str:
def main(backend_url: str = "http://127.0.0.1:8080", api_key: str = "1234"):
"""automates tasks"""
# make sure dummy user has accepted the terms of service
create_user_request = dict(USER)
create_user_request["tos_acceptance"] = True
response = requests.post(
f"{backend_url}/api/v1/frontend_users/", json=create_user_request, headers={"X-API-Key": api_key}
)
response.raise_for_status()
user = response.json()
typer.echo(f"user: {user}")
def _post(path: str, json: dict) -> dict:
response = requests.post(f"{backend_url}{path}", json=json, headers={"X-API-Key": api_key})
response.raise_for_status()