857: Implement backend Discord authentication (#943)

* Initial code for backend auth

* Remove outdated check

* Initial separation of AuthenticatedUser

* AuthenticatedUser -> Account

* Rework for Account

* Initial code for Discord OAuth

* Remove now redundant methods

* Remove incorrect response model, add requests dep for backend

* Create Settings fields for Discord values

* Cleanup get account from Discord function

* Cleanup

* Cleanup

* Generate alembic upgrade script

* Remove unused error codes

* Update alembic script to correct down revision

* Use aiohttp over requests

* Update alembic script to latest down revision
This commit is contained in:
Oliver Stanley
2023-02-02 18:33:21 +00:00
committed by GitHub
parent dfd2c35276
commit 652b7bff04
8 changed files with 187 additions and 0 deletions
@@ -0,0 +1,41 @@
"""Add Account table
Revision ID: 8c8241d1f973
Revises: 4d7e0b0ebe84
Create Date: 2023-01-30 15:10:58.776315
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "8c8241d1f973"
down_revision = "4d7e0b0ebe84"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"account",
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("provider", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
sa.Column("provider_account_id", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("id", "account", [], unique=True)
op.create_index("provider", "account", ["provider_account_id"], unique=True)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("provider", table_name="account")
op.drop_index("id", table_name="account")
op.drop_table("account")
# ### end Alembic commands ###
+73
View File
@@ -0,0 +1,73 @@
import aiohttp
from fastapi import APIRouter, Depends, HTTPException, Request
from oasst_backend import auth
from oasst_backend.api import deps
from oasst_backend.config import Settings
from oasst_backend.models import Account
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_401_UNAUTHORIZED
router = APIRouter()
@router.get("/discord")
def login_discord(request: Request):
redirect_uri = f"{get_callback_uri(request)}/discord"
auth_url = f"https://discord.com/api/oauth2/authorize?client_id={Settings.AUTH_DISCORD_CLIENT_ID}&redirect_uri={redirect_uri}&response_type=code&scope=identify"
raise HTTPException(status_code=302, headers={"location": auth_url})
@router.get("/callback/discord", response_model=protocol_schema.Token)
async def callback_discord(
auth_code: str,
request: Request,
db: Session = Depends(deps.get_db),
):
redirect_uri = f"{get_callback_uri(request)}/discord"
async with aiohttp.ClientSession(raise_for_status=True) as session:
# Exchange the auth code for a Discord access token
async with session.post(
"https://discord.com/api/oauth2/token",
data={
"client_id": Settings.AUTH_DISCORD_CLIENT_ID,
"client_secret": Settings.AUTH_DISCORD_CLIENT_SECRET,
"grant_type": "authorization_code",
"code": auth_code,
"redirect_uri": redirect_uri,
"scope": "identify",
},
) as token_response:
token_response_json = await token_response.json()
access_token = token_response_json["access_token"]
# Retrieve user's Discord information using access token
async with session.get(
"https://discord.com/api/users/@me", headers={"Authorization": f"Bearer {access_token}"}
) as user_response:
user_response_json = await user_response.json()
discord_id = user_response_json["id"]
account: Account = auth.get_account_from_discord_id(db, discord_id)
if not account:
# Discord account is not linked to an OA account
raise OasstError("Invalid authentication", OasstErrorCode.INVALID_AUTHENTICATION, HTTP_401_UNAUTHORIZED)
# Discord account is valid and linked to an OA account -> create JWT
access_token = auth.create_access_token(account)
return protocol_schema.Token(access_token=access_token, token_type="bearer")
def get_callback_uri(request: Request):
"""
Gets the URI for the base callback endpoint with no provider name appended.
"""
# This seems ugly, not sure if there is a better way
current_url = str(request.url)
domain = current_url.split("/api/v1/")[0]
redirect_uri = f"{domain}/api/v1/callback"
return redirect_uri
+37
View File
@@ -0,0 +1,37 @@
from datetime import datetime, timedelta
from typing import Optional
from jose import jwt
from oasst_backend.config import Settings
from oasst_backend.models import Account
from sqlmodel import Session
def create_access_token(data: dict) -> str:
"""
Create an encoded JSON Web Token (JWT) using the given data.
"""
expires_delta = timedelta(minutes=Settings.AUTH_ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode = data.copy()
expire = datetime.utcnow() + expires_delta
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, Settings.AUTH_SECRET, algorithm=Settings.AUTH_ALGORITHM)
return encoded_jwt
def get_account_from_discord_id(db: Session, discord_id: str) -> Optional[Account]:
"""
Get the Open-Assistant Account associated with the given Discord ID.
"""
account: Account = (
db.query(Account)
.filter(
Account.provider == "discord",
Account.provider_account_id == discord_id,
)
.first()
)
return account
+5
View File
@@ -135,6 +135,11 @@ class Settings(BaseSettings):
AUTH_LENGTH: int = 32
AUTH_SECRET: bytes = b"O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98="
AUTH_COOKIE_NAME: str = "next-auth.session-token"
AUTH_ALGORITHM: str = "HS256"
AUTH_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
AUTH_DISCORD_CLIENT_ID: str = ""
AUTH_DISCORD_CLIENT_SECRET: str = ""
POSTGRES_HOST: str = "localhost"
POSTGRES_PORT: str = "5432"
+17
View File
@@ -60,3 +60,20 @@ class User(SQLModel, table=True):
last_activity_date=self.last_activity_date,
tos_acceptance_date=self.tos_acceptance_date,
)
class Account(SQLModel, table=True):
__tablename__ = "account"
__table_args__ = (
Index("id", unique=True),
Index("provider", "provider_account_id", unique=True),
)
id: Optional[UUID] = Field(
sa_column=sa.Column(
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
),
)
user_id: UUID = Field(foreign_key="user.id")
provider: str = Field(nullable=False, max_length=128, default="email") # discord or email
provider_account_id: str = Field(nullable=False, max_length=128)
+1
View File
@@ -1,3 +1,4 @@
aiohttp==3.8.3
alembic==1.8.1
cryptography==39.0.0
fastapi==0.88.0
@@ -28,6 +28,8 @@ class OasstErrorCode(IntEnum):
SERVER_ERROR0 = 500
SERVER_ERROR1 = 501
INVALID_AUTHENTICATION = 600
# 1000-2000: tasks endpoint
TASK_INVALID_REQUEST_TYPE = 1000
TASK_ACK_FAILED = 1001
@@ -29,6 +29,17 @@ class User(BaseModel):
auth_method: Literal["discord", "local", "system"]
class Account(BaseModel):
id: UUID
provider: str
provider_account_id: str
class Token(BaseModel):
access_token: str
token_type: str
class FrontEndUser(User):
user_id: UUID
enabled: bool