From 652b7bff04ec6c4b273d8540eb3b3a55da4de1b1 Mon Sep 17 00:00:00 2001 From: Oliver Stanley Date: Thu, 2 Feb 2023 18:33:21 +0000 Subject: [PATCH] 857: Implement backend Discord authentication (#943) * Initial code for backend auth * Remove outdated check * Initial separation of AuthenticatedUser * AuthenticatedUser -> Account * Rework for Account * Initial code for Discord OAuth * Remove now redundant methods * Remove incorrect response model, add requests dep for backend * Create Settings fields for Discord values * Cleanup get account from Discord function * Cleanup * Cleanup * Generate alembic upgrade script * Remove unused error codes * Update alembic script to correct down revision * Use aiohttp over requests * Update alembic script to latest down revision --- ..._02_1817-8c8241d1f973_add_account_table.py | 41 +++++++++++ backend/oasst_backend/api/v1/login.py | 73 +++++++++++++++++++ backend/oasst_backend/auth.py | 37 ++++++++++ backend/oasst_backend/config.py | 5 ++ backend/oasst_backend/models/user.py | 17 +++++ backend/requirements.txt | 1 + .../exceptions/oasst_api_error.py | 2 + oasst-shared/oasst_shared/schemas/protocol.py | 11 +++ 8 files changed, 187 insertions(+) create mode 100644 backend/alembic/versions/2023_02_02_1817-8c8241d1f973_add_account_table.py create mode 100644 backend/oasst_backend/api/v1/login.py create mode 100644 backend/oasst_backend/auth.py diff --git a/backend/alembic/versions/2023_02_02_1817-8c8241d1f973_add_account_table.py b/backend/alembic/versions/2023_02_02_1817-8c8241d1f973_add_account_table.py new file mode 100644 index 00000000..3ec2708c --- /dev/null +++ b/backend/alembic/versions/2023_02_02_1817-8c8241d1f973_add_account_table.py @@ -0,0 +1,41 @@ +"""Add Account table + +Revision ID: 8c8241d1f973 +Revises: 4d7e0b0ebe84 +Create Date: 2023-01-30 15:10:58.776315 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "8c8241d1f973" +down_revision = "4d7e0b0ebe84" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "account", + sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False), + sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("provider", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False), + sa.Column("provider_account_id", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["user.id"]), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index("id", "account", [], unique=True) + op.create_index("provider", "account", ["provider_account_id"], unique=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("provider", table_name="account") + op.drop_index("id", table_name="account") + op.drop_table("account") + # ### end Alembic commands ### diff --git a/backend/oasst_backend/api/v1/login.py b/backend/oasst_backend/api/v1/login.py new file mode 100644 index 00000000..8aab5328 --- /dev/null +++ b/backend/oasst_backend/api/v1/login.py @@ -0,0 +1,73 @@ +import aiohttp +from fastapi import APIRouter, Depends, HTTPException, Request +from oasst_backend import auth +from oasst_backend.api import deps +from oasst_backend.config import Settings +from oasst_backend.models import Account +from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode +from oasst_shared.schemas import protocol as protocol_schema +from sqlmodel import Session +from starlette.status import HTTP_401_UNAUTHORIZED + +router = APIRouter() + + +@router.get("/discord") +def login_discord(request: Request): + redirect_uri = f"{get_callback_uri(request)}/discord" + auth_url = f"https://discord.com/api/oauth2/authorize?client_id={Settings.AUTH_DISCORD_CLIENT_ID}&redirect_uri={redirect_uri}&response_type=code&scope=identify" + raise HTTPException(status_code=302, headers={"location": auth_url}) + + +@router.get("/callback/discord", response_model=protocol_schema.Token) +async def callback_discord( + auth_code: str, + request: Request, + db: Session = Depends(deps.get_db), +): + redirect_uri = f"{get_callback_uri(request)}/discord" + + async with aiohttp.ClientSession(raise_for_status=True) as session: + # Exchange the auth code for a Discord access token + async with session.post( + "https://discord.com/api/oauth2/token", + data={ + "client_id": Settings.AUTH_DISCORD_CLIENT_ID, + "client_secret": Settings.AUTH_DISCORD_CLIENT_SECRET, + "grant_type": "authorization_code", + "code": auth_code, + "redirect_uri": redirect_uri, + "scope": "identify", + }, + ) as token_response: + token_response_json = await token_response.json() + access_token = token_response_json["access_token"] + + # Retrieve user's Discord information using access token + async with session.get( + "https://discord.com/api/users/@me", headers={"Authorization": f"Bearer {access_token}"} + ) as user_response: + user_response_json = await user_response.json() + discord_id = user_response_json["id"] + + account: Account = auth.get_account_from_discord_id(db, discord_id) + + if not account: + # Discord account is not linked to an OA account + raise OasstError("Invalid authentication", OasstErrorCode.INVALID_AUTHENTICATION, HTTP_401_UNAUTHORIZED) + + # Discord account is valid and linked to an OA account -> create JWT + access_token = auth.create_access_token(account) + + return protocol_schema.Token(access_token=access_token, token_type="bearer") + + +def get_callback_uri(request: Request): + """ + Gets the URI for the base callback endpoint with no provider name appended. + """ + # This seems ugly, not sure if there is a better way + current_url = str(request.url) + domain = current_url.split("/api/v1/")[0] + redirect_uri = f"{domain}/api/v1/callback" + return redirect_uri diff --git a/backend/oasst_backend/auth.py b/backend/oasst_backend/auth.py new file mode 100644 index 00000000..2c633fa4 --- /dev/null +++ b/backend/oasst_backend/auth.py @@ -0,0 +1,37 @@ +from datetime import datetime, timedelta +from typing import Optional + +from jose import jwt +from oasst_backend.config import Settings +from oasst_backend.models import Account +from sqlmodel import Session + + +def create_access_token(data: dict) -> str: + """ + Create an encoded JSON Web Token (JWT) using the given data. + """ + + expires_delta = timedelta(minutes=Settings.AUTH_ACCESS_TOKEN_EXPIRE_MINUTES) + to_encode = data.copy() + expire = datetime.utcnow() + expires_delta + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, Settings.AUTH_SECRET, algorithm=Settings.AUTH_ALGORITHM) + return encoded_jwt + + +def get_account_from_discord_id(db: Session, discord_id: str) -> Optional[Account]: + """ + Get the Open-Assistant Account associated with the given Discord ID. + """ + + account: Account = ( + db.query(Account) + .filter( + Account.provider == "discord", + Account.provider_account_id == discord_id, + ) + .first() + ) + + return account diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index f40f5637..851f9ace 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -135,6 +135,11 @@ class Settings(BaseSettings): AUTH_LENGTH: int = 32 AUTH_SECRET: bytes = b"O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98=" AUTH_COOKIE_NAME: str = "next-auth.session-token" + AUTH_ALGORITHM: str = "HS256" + AUTH_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 + + AUTH_DISCORD_CLIENT_ID: str = "" + AUTH_DISCORD_CLIENT_SECRET: str = "" POSTGRES_HOST: str = "localhost" POSTGRES_PORT: str = "5432" diff --git a/backend/oasst_backend/models/user.py b/backend/oasst_backend/models/user.py index 6c73089c..69fc3e37 100644 --- a/backend/oasst_backend/models/user.py +++ b/backend/oasst_backend/models/user.py @@ -60,3 +60,20 @@ class User(SQLModel, table=True): last_activity_date=self.last_activity_date, tos_acceptance_date=self.tos_acceptance_date, ) + + +class Account(SQLModel, table=True): + __tablename__ = "account" + __table_args__ = ( + Index("id", unique=True), + Index("provider", "provider_account_id", unique=True), + ) + + id: Optional[UUID] = Field( + sa_column=sa.Column( + pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()") + ), + ) + user_id: UUID = Field(foreign_key="user.id") + provider: str = Field(nullable=False, max_length=128, default="email") # discord or email + provider_account_id: str = Field(nullable=False, max_length=128) diff --git a/backend/requirements.txt b/backend/requirements.txt index 4a112bc8..4a0008bb 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,3 +1,4 @@ +aiohttp==3.8.3 alembic==1.8.1 cryptography==39.0.0 fastapi==0.88.0 diff --git a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py index a8682a32..58c8aadc 100644 --- a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py +++ b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py @@ -28,6 +28,8 @@ class OasstErrorCode(IntEnum): SERVER_ERROR0 = 500 SERVER_ERROR1 = 501 + INVALID_AUTHENTICATION = 600 + # 1000-2000: tasks endpoint TASK_INVALID_REQUEST_TYPE = 1000 TASK_ACK_FAILED = 1001 diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index e0dde366..a237b0c9 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -29,6 +29,17 @@ class User(BaseModel): auth_method: Literal["discord", "local", "system"] +class Account(BaseModel): + id: UUID + provider: str + provider_account_id: str + + +class Token(BaseModel): + access_token: str + token_type: str + + class FrontEndUser(User): user_id: UUID enabled: bool