"""FastAPI dependencies that authenticate requests from a Bearer JWT."""

from typing import Annotated, Dict, Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import jwt, JWTError
from sqlmodel import Session

from app.core.config import settings
from app.core.security import ALGORITHM
from app.core.database import get_session
from app.models.user import User

# auto_error=False: a missing Authorization header yields token=None instead of
# an automatic 401, so each dependency decides how to handle anonymous callers.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/auth/login", auto_error=False)


def _decode_token(token: str) -> Optional[Dict]:
    """Verify *token* with the app secret and algorithm and return its claims.

    Returns None when python-jose rejects the token (``JWTError``). This is the
    single place the decode parameters live, so the strict and optional
    dependencies below cannot drift apart.
    """
    try:
        return jwt.decode(token, settings.secret_key, algorithms=[ALGORITHM])
    except JWTError:
        return None


def get_token_payload(token: Annotated[Optional[str], Depends(oauth2_scheme)]) -> Dict:
    """Return the verified claims of the request's Bearer token.

    Raises:
        HTTPException: 401 "Not authenticated" if no token was sent, or
            401 "Could not validate credentials" if the token fails to decode.
    """
    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )
    payload = _decode_token(token)
    # Compare against None, not falsiness: a token with an empty claim set
    # still decoded successfully and is handled by the callers.
    if payload is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return payload


async def get_current_user_token(token: Annotated[Optional[str], Depends(oauth2_scheme)]) -> Dict:
    """Dependency returning the verified token claims (see ``get_token_payload``)."""
    return get_token_payload(token)


async def get_current_admin(token: Annotated[Optional[str], Depends(oauth2_scheme)]) -> bool:
    """Dependency that allows only tokens whose ``role`` claim is admin.

    Returns:
        True when the caller is an admin.

    Raises:
        HTTPException: 401 from ``get_token_payload`` for a missing or invalid
            token; 403 when the ``role`` claim is absent or not an admin value.
    """
    # Kept as a function-local import, as in the original module.
    from app.models.user import UserRole

    payload = get_token_payload(token)
    role: Optional[str] = payload.get("role")
    # Accept both plain-string spellings as well as the enum member itself.
    admin_roles = ("admin", "ADMIN", UserRole.ADMIN)
    if role not in admin_roles:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="The user doesn't have enough privileges",
        )
    return True


async def get_current_user(
    token: Annotated[Optional[str], Depends(oauth2_scheme)],
    session: Session = Depends(get_session),
) -> User:
    """Dependency resolving the token's ``sub`` claim to a ``User`` row.

    Raises:
        HTTPException: 401 for a missing or invalid token or a missing ``sub``
            claim; 404 when no user matches ``sub``.
    """
    payload = get_token_payload(token)
    user_id = payload.get("sub")
    if not user_id:
        # RFC 7235 requires a WWW-Authenticate header on every 401 response,
        # matching the other 401s raised by get_token_payload.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
    user = session.get(User, user_id)
    if not user:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
    return user


async def get_current_user_optional(
    token: Optional[str] = Depends(oauth2_scheme),
    session: Session = Depends(get_session),
) -> Optional[User]:
    """Like ``get_current_user`` but never raises for authentication problems.

    Returns None when no token is sent, the token fails to decode, it has no
    ``sub`` claim, or no user matches ``sub``. Otherwise returns the ``User``.
    """
    if not token:
        return None
    payload = _decode_token(token)
    if payload is None:
        return None
    user_id = payload.get("sub")
    if not user_id:
        return None
    return session.get(User, user_id)