Tech Stack Advisor - Code Viewer

← Back to File Tree

auth.py

Language: python | Path: backend/src/core/auth.py | Lines: 280
"""JWT-based authentication module."""
import time
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

import jwt

from .config import settings
from .logging import get_logger
from .user_memory import get_user_memory_store

logger = get_logger(__name__)

# JWT settings
# Placeholder used only when no secret is configured. Anyone who knows this
# string can forge tokens, so it must never be relied on outside development.
_DEFAULT_JWT_SECRET = "your-secret-key-change-in-production"

# Fall back to the placeholder when the setting is absent *or* empty: an empty
# HMAC key would otherwise be accepted silently and is trivially forgeable.
JWT_SECRET = getattr(settings, "jwt_secret", None) or _DEFAULT_JWT_SECRET
if JWT_SECRET == _DEFAULT_JWT_SECRET:
    # Make an insecure deployment visible instead of failing silently.
    logger.warning("jwt_secret_not_configured")
JWT_ALGORITHM = "HS256"
JWT_EXPIRATION_HOURS = 24

# In-memory token storage (for logout/revocation).
# Maps each issued token string to {"user_id", "created_at", "expires_at"}
# (populated by create_access_token). It is per-process and lost on restart.
# In production, use Redis or similar.
_active_tokens: Dict[str, Dict[str, Any]] = {}


def create_access_token(user_id: str, email: str, is_admin: bool = False) -> str:
    """Create a signed JWT access token and register it as active.

    The token carries ``user_id``, ``email``, ``is_admin``, ``iat`` and ``exp``
    claims and is signed with ``JWT_SECRET`` using ``JWT_ALGORITHM``. It is
    also recorded in ``_active_tokens`` so that ``verify_token`` accepts it
    and ``revoke_token`` can later invalidate it.

    Args:
        user_id: User ID
        email: User email
        is_admin: Whether user is admin

    Returns:
        JWT token string, valid for ``JWT_EXPIRATION_HOURS`` hours.
    """
    current_time = time.time()
    expiration_time = current_time + (JWT_EXPIRATION_HOURS * 3600)

    payload = {
        "user_id": user_id,
        "email": email,
        "is_admin": is_admin,
        "exp": expiration_time,
        "iat": current_time,
    }

    token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
    # PyJWT < 2.0 returns bytes. Normalize so the registry key and the return
    # value are always ``str`` and match the string clients send back.
    if isinstance(token, bytes):
        token = token.decode("ascii")

    # Store token for revocation check: verify_token rejects any token that
    # is not present in this registry.
    _active_tokens[token] = {
        "user_id": user_id,
        "created_at": current_time,
        "expires_at": expiration_time,
    }

    # Log expiry as a timezone-aware UTC timestamp. A naive local-time string
    # is ambiguous across hosts and DST changes.
    expires_at = datetime.fromtimestamp(expiration_time, tz=timezone.utc).isoformat()
    logger.info("token_created", user_id=user_id, expires_at=expires_at)
    return token


def verify_token(token: str) -> Optional[Dict[str, Any]]:
    """Validate a JWT against the active-token registry and its signature.

    A token is accepted only when both of these hold:

    - It is still present in ``_active_tokens``. Tokens removed by
      ``revoke_token`` or by cleanup are rejected.
    - ``jwt.decode`` succeeds with ``JWT_SECRET`` / ``JWT_ALGORITHM``. This
      step also enforces the ``exp`` claim.

    Args:
        token: JWT token string

    Returns:
        The decoded claims dict on success, otherwise None.
    """
    try:
        # Unknown tokens are rejected before any cryptographic work.
        is_registered = token in _active_tokens
        if not is_registered:
            logger.warning("token_revoked_or_invalid")
            return None

        claims = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except jwt.ExpiredSignatureError:
        # Must precede InvalidTokenError, of which it is a subclass.
        logger.warning("token_expired")
        # An expired token can never verify again; drop it from the registry.
        _active_tokens.pop(token, None)
        return None
    except jwt.InvalidTokenError as err:
        logger.warning("token_invalid", error=str(err))
        return None

    logger.info("token_verified", user_id=claims.get("user_id"))
    return claims


def revoke_token(token: str) -> bool:
    """Revoke a token (logout) by removing it from the active-token registry.

    Once removed, ``verify_token`` rejects the token even if its signature
    and expiry are still valid.

    Args:
        token: JWT token to revoke

    Returns:
        True if the token was active and is now revoked. False if it was not
        found (never issued here, already revoked, or already purged).
    """
    # A single pop replaces check-then-delete. It avoids the double lookup,
    # and a second revoke of the same token cannot hit KeyError between the
    # membership test and the delete.
    entry = _active_tokens.pop(token, None)
    if entry is None:
        return False
    logger.info("token_revoked", user_id=entry["user_id"])
    return True


def authenticate_user(email: str, password: str) -> Optional[Dict[str, Any]]:
    """Check email/password credentials and issue an access token.

    Args:
        email: User email
        password: User password

    Returns:
        ``{"user": {user_id, email, full_name, is_admin}, "token": <jwt>}``
        when the memory store's ``verify_password`` returns a user record,
        otherwise None.
    """
    user = get_user_memory_store().verify_password(email, password)

    # Guard clause: bad credentials exit early.
    if not user:
        logger.warning("authentication_failed", email=email)
        return None

    user_id = user["user_id"]
    admin_flag = user.get("is_admin", False)
    token = create_access_token(
        user_id=user_id,
        email=user["email"],
        is_admin=admin_flag,
    )

    logger.info("user_authenticated", user_id=user_id, email=email)

    # Only these public fields are exposed to the caller.
    public_profile = {
        "user_id": user_id,
        "email": user["email"],
        "full_name": user.get("full_name"),
        "is_admin": admin_flag,
    }
    return {"user": public_profile, "token": token}


def register_user(
    email: str,
    password: str,
    full_name: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Register a new user and issue an access token.

    Args:
        email: User email
        password: User password
        full_name: User's full name

    Returns:
        User data and token if successful. Returns None if the memory store
        rejects the registration with ``ValueError`` (e.g. the user already
        exists).
    """
    memory_store = get_user_memory_store()

    # Only the store call belongs in the try. A ValueError raised later (e.g.
    # while issuing the token) must not be misreported as a failed
    # registration after the account has already been created.
    try:
        user = memory_store.create_user(
            email=email,
            password=password,
            full_name=full_name,
        )
    except ValueError as e:
        logger.warning("registration_failed", email=email, error=str(e))
        return None

    token = create_access_token(
        user_id=user["user_id"],
        email=user["email"],
        is_admin=user.get("is_admin", False),
    )

    logger.info("user_registered", user_id=user["user_id"], email=email)

    return {
        "user": {
            "user_id": user["user_id"],
            "email": user["email"],
            "full_name": user.get("full_name"),
            "is_admin": user.get("is_admin", False),
        },
        "token": token,
    }


def authenticate_oauth(
    email: str,
    oauth_provider: str,
    oauth_id: str,
    full_name: Optional[str] = None,
) -> Dict[str, Any]:
    """Log in an OAuth user, creating the account on first sign-in.

    Args:
        email: User email from OAuth provider
        oauth_provider: Provider name (google, microsoft, linkedin, facebook)
        oauth_id: User ID from OAuth provider
        full_name: User's full name from OAuth

    Returns:
        ``{"user": {user_id, email, full_name, is_admin}, "token": <jwt>}``.
    """
    store = get_user_memory_store()
    existing = store.get_user_by_email(email)

    if existing:
        # Returning user: only the last-login timestamp is refreshed.
        # NOTE(review): the account is matched on email alone. oauth_provider
        # and oauth_id are not compared with anything stored, so any caller
        # presenting this email gets a token for the existing account
        # (including password-based ones). Confirm every provider verifies
        # emails before this point.
        store.update_user(existing["user_id"], {"last_login": time.time()})
        logger.info("oauth_user_login", email=email, provider=oauth_provider)
        user = existing
    else:
        user = store.create_user(
            email=email,
            full_name=full_name,
            oauth_provider=oauth_provider,
            oauth_id=oauth_id,
        )
        logger.info("oauth_user_created", email=email, provider=oauth_provider)

    admin_flag = user.get("is_admin", False)
    token = create_access_token(
        user_id=user["user_id"],
        email=user["email"],
        is_admin=admin_flag,
    )

    public_profile = {
        "user_id": user["user_id"],
        "email": user["email"],
        "full_name": user.get("full_name"),
        "is_admin": admin_flag,
    }
    return {"user": public_profile, "token": token}


def cleanup_expired_tokens() -> int:
    """Purge registry entries whose recorded ``expires_at`` is in the past.

    Nothing else in this module calls this. Presumably it is meant to be run
    periodically by the application (confirm with callers); otherwise expired
    entries are only dropped lazily by ``verify_token``.

    Returns:
        Number of tokens removed
    """
    cutoff = time.time()

    # Collect first, delete second: the dict must not change size while it is
    # being iterated.
    stale = [tok for tok, meta in _active_tokens.items() if meta["expires_at"] < cutoff]
    for tok in stale:
        _active_tokens.pop(tok)

    removed = len(stale)
    if removed > 0:
        logger.info("expired_tokens_cleaned", count=removed)
    return removed


def get_current_user(token: str) -> Optional[Dict[str, Any]]:
    """Resolve the stored user record behind a bearer token.

    Args:
        token: JWT token

    Returns:
        None if ``verify_token`` rejects the token. Otherwise, whatever the
        memory store's ``get_user_by_id`` returns for the token's ``user_id``
        claim (presumably None for an unknown id; verify against the store).
        NOTE(review): this is the raw store record, not the filtered public
        profile the login functions return. Confirm it carries no sensitive
        fields before exposing it to clients.
    """
    claims = verify_token(token)
    if claims:
        return get_user_memory_store().get_user_by_id(claims["user_id"])
    return None