mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-04 17:20:19 +08:00
857: Implement backend Discord authentication (#943)
* Initial code for backend auth * Remove outdated check * Initial separation of AuthenticatedUser * AuthenticatedUser -> Account * Rework for Account * Initial code for Discord OAuth * Remove now redundant methods * Remove incorrect response model, add requests dep for backend * Create Settings fields for Discord values * Cleanup get account from Discord function * Cleanup * Cleanup * Generate alembic upgrade script * Remove unused error codes * Update alembic script to correct down revision * Use aiohttp over requests * Update alembic script to latest down revision
This commit is contained in:
@@ -0,0 +1,41 @@
|
||||
"""Add Account table
|
||||
|
||||
Revision ID: 8c8241d1f973
|
||||
Revises: 4d7e0b0ebe84
|
||||
Create Date: 2023-01-30 15:10:58.776315
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8c8241d1f973"
|
||||
down_revision = "4d7e0b0ebe84"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"account",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), nullable=False),
|
||||
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
|
||||
sa.Column("provider", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
|
||||
sa.Column("provider_account_id", sqlmodel.sql.sqltypes.AutoString(length=128), nullable=False),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("id", "account", [], unique=True)
|
||||
op.create_index("provider", "account", ["provider_account_id"], unique=True)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("provider", table_name="account")
|
||||
op.drop_index("id", table_name="account")
|
||||
op.drop_table("account")
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,73 @@
|
||||
import aiohttp
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from oasst_backend import auth
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.config import Settings
|
||||
from oasst_backend.models import Account
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/discord")
|
||||
def login_discord(request: Request):
|
||||
redirect_uri = f"{get_callback_uri(request)}/discord"
|
||||
auth_url = f"https://discord.com/api/oauth2/authorize?client_id={Settings.AUTH_DISCORD_CLIENT_ID}&redirect_uri={redirect_uri}&response_type=code&scope=identify"
|
||||
raise HTTPException(status_code=302, headers={"location": auth_url})
|
||||
|
||||
|
||||
@router.get("/callback/discord", response_model=protocol_schema.Token)
|
||||
async def callback_discord(
|
||||
auth_code: str,
|
||||
request: Request,
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
redirect_uri = f"{get_callback_uri(request)}/discord"
|
||||
|
||||
async with aiohttp.ClientSession(raise_for_status=True) as session:
|
||||
# Exchange the auth code for a Discord access token
|
||||
async with session.post(
|
||||
"https://discord.com/api/oauth2/token",
|
||||
data={
|
||||
"client_id": Settings.AUTH_DISCORD_CLIENT_ID,
|
||||
"client_secret": Settings.AUTH_DISCORD_CLIENT_SECRET,
|
||||
"grant_type": "authorization_code",
|
||||
"code": auth_code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": "identify",
|
||||
},
|
||||
) as token_response:
|
||||
token_response_json = await token_response.json()
|
||||
access_token = token_response_json["access_token"]
|
||||
|
||||
# Retrieve user's Discord information using access token
|
||||
async with session.get(
|
||||
"https://discord.com/api/users/@me", headers={"Authorization": f"Bearer {access_token}"}
|
||||
) as user_response:
|
||||
user_response_json = await user_response.json()
|
||||
discord_id = user_response_json["id"]
|
||||
|
||||
account: Account = auth.get_account_from_discord_id(db, discord_id)
|
||||
|
||||
if not account:
|
||||
# Discord account is not linked to an OA account
|
||||
raise OasstError("Invalid authentication", OasstErrorCode.INVALID_AUTHENTICATION, HTTP_401_UNAUTHORIZED)
|
||||
|
||||
# Discord account is valid and linked to an OA account -> create JWT
|
||||
access_token = auth.create_access_token(account)
|
||||
|
||||
return protocol_schema.Token(access_token=access_token, token_type="bearer")
|
||||
|
||||
|
||||
def get_callback_uri(request: Request):
|
||||
"""
|
||||
Gets the URI for the base callback endpoint with no provider name appended.
|
||||
"""
|
||||
# This seems ugly, not sure if there is a better way
|
||||
current_url = str(request.url)
|
||||
domain = current_url.split("/api/v1/")[0]
|
||||
redirect_uri = f"{domain}/api/v1/callback"
|
||||
return redirect_uri
|
||||
@@ -0,0 +1,37 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from jose import jwt
|
||||
from oasst_backend.config import Settings
|
||||
from oasst_backend.models import Account
|
||||
from sqlmodel import Session
|
||||
|
||||
|
||||
def create_access_token(data: dict) -> str:
|
||||
"""
|
||||
Create an encoded JSON Web Token (JWT) using the given data.
|
||||
"""
|
||||
|
||||
expires_delta = timedelta(minutes=Settings.AUTH_ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, Settings.AUTH_SECRET, algorithm=Settings.AUTH_ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def get_account_from_discord_id(db: Session, discord_id: str) -> Optional[Account]:
|
||||
"""
|
||||
Get the Open-Assistant Account associated with the given Discord ID.
|
||||
"""
|
||||
|
||||
account: Account = (
|
||||
db.query(Account)
|
||||
.filter(
|
||||
Account.provider == "discord",
|
||||
Account.provider_account_id == discord_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
return account
|
||||
@@ -135,6 +135,11 @@ class Settings(BaseSettings):
|
||||
AUTH_LENGTH: int = 32
|
||||
AUTH_SECRET: bytes = b"O/M2uIbGj+lDD2oyNa8ax4jEOJqCPJzO53UbWShmq98="
|
||||
AUTH_COOKIE_NAME: str = "next-auth.session-token"
|
||||
AUTH_ALGORITHM: str = "HS256"
|
||||
AUTH_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||
|
||||
AUTH_DISCORD_CLIENT_ID: str = ""
|
||||
AUTH_DISCORD_CLIENT_SECRET: str = ""
|
||||
|
||||
POSTGRES_HOST: str = "localhost"
|
||||
POSTGRES_PORT: str = "5432"
|
||||
|
||||
@@ -60,3 +60,20 @@ class User(SQLModel, table=True):
|
||||
last_activity_date=self.last_activity_date,
|
||||
tos_acceptance_date=self.tos_acceptance_date,
|
||||
)
|
||||
|
||||
|
||||
class Account(SQLModel, table=True):
|
||||
__tablename__ = "account"
|
||||
__table_args__ = (
|
||||
Index("id", unique=True),
|
||||
Index("provider", "provider_account_id", unique=True),
|
||||
)
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
|
||||
),
|
||||
)
|
||||
user_id: UUID = Field(foreign_key="user.id")
|
||||
provider: str = Field(nullable=False, max_length=128, default="email") # discord or email
|
||||
provider_account_id: str = Field(nullable=False, max_length=128)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
aiohttp==3.8.3
|
||||
alembic==1.8.1
|
||||
cryptography==39.0.0
|
||||
fastapi==0.88.0
|
||||
|
||||
@@ -28,6 +28,8 @@ class OasstErrorCode(IntEnum):
|
||||
SERVER_ERROR0 = 500
|
||||
SERVER_ERROR1 = 501
|
||||
|
||||
INVALID_AUTHENTICATION = 600
|
||||
|
||||
# 1000-2000: tasks endpoint
|
||||
TASK_INVALID_REQUEST_TYPE = 1000
|
||||
TASK_ACK_FAILED = 1001
|
||||
|
||||
@@ -29,6 +29,17 @@ class User(BaseModel):
|
||||
auth_method: Literal["discord", "local", "system"]
|
||||
|
||||
|
||||
class Account(BaseModel):
|
||||
id: UUID
|
||||
provider: str
|
||||
provider_account_id: str
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
class FrontEndUser(User):
|
||||
user_id: UUID
|
||||
enabled: bool
|
||||
|
||||
Reference in New Issue
Block a user