"""JWT-based authentication module."""
import time
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any

import jwt

from .config import settings
from .logging import get_logger
from .user_memory import get_user_memory_store
logger = get_logger(__name__)

# JWT settings
# Placeholder secret used only when the config object has no ``jwt_secret``
# attribute (or is explicitly set to this string). Anyone who knows it can
# produce correctly signed HS256 tokens, so its use is logged as a warning.
_DEFAULT_JWT_SECRET = "your-secret-key-change-in-production"
JWT_SECRET = getattr(settings, "jwt_secret", _DEFAULT_JWT_SECRET)
if JWT_SECRET == _DEFAULT_JWT_SECRET:
    logger.warning("jwt_secret_using_insecure_default")
JWT_ALGORITHM = "HS256"
JWT_EXPIRATION_HOURS = 24

# In-memory registry of issued tokens, used for logout/revocation.
# Maps the encoded token string to {"user_id", "created_at", "expires_at"}.
# Being a module-level dict, it exists only inside this process: it is
# emptied on restart and not shared between worker processes.
# In production, use Redis or similar.
_active_tokens: Dict[str, Dict[str, Any]] = {}
def create_access_token(user_id: str, email: str, is_admin: bool = False) -> str:
    """Create a signed JWT access token and register it for revocation checks.

    Args:
        user_id: User ID, stored in the ``user_id`` claim.
        email: User email, stored in the ``email`` claim.
        is_admin: Whether the user is an admin, stored in the ``is_admin`` claim.

    Returns:
        The encoded JWT string. The token is also recorded in the in-memory
        ``_active_tokens`` registry so it can later be revoked.
    """
    issued_at = time.time()
    expires_at = issued_at + (JWT_EXPIRATION_HOURS * 3600)
    payload = {
        "user_id": user_id,
        "email": email,
        "is_admin": is_admin,
        "exp": expires_at,
        "iat": issued_at,
    }
    token = jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)

    # Register the token: verification rejects any token that is not present
    # in this registry, which is how logout/revocation takes effect.
    _active_tokens[token] = {
        "user_id": user_id,
        "created_at": issued_at,
        "expires_at": expires_at,
    }

    # Log the expiry as an explicit UTC timestamp; a naive local-time value
    # is ambiguous when comparing logs from hosts in different timezones.
    logger.info(
        "token_created",
        user_id=user_id,
        expires_at=datetime.fromtimestamp(expires_at, tz=timezone.utc).isoformat(),
    )
    return token
def verify_token(token: str) -> Optional[Dict[str, Any]]:
    """Check a token against the revocation registry and validate its signature.

    Args:
        token: Encoded JWT string.

    Returns:
        The decoded claims when the token is registered, correctly signed and
        unexpired; ``None`` in every other case.
    """
    # Unknown tokens (never issued here, or already revoked) are rejected
    # before any decoding is attempted.
    if token not in _active_tokens:
        logger.warning("token_revoked_or_invalid")
        return None

    try:
        claims = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except jwt.ExpiredSignatureError:
        # Must be caught before InvalidTokenError (its base class).
        logger.warning("token_expired")
        # An expired token can never become valid again; drop it.
        _active_tokens.pop(token, None)
        return None
    except jwt.InvalidTokenError as exc:
        logger.warning("token_invalid", error=str(exc))
        return None

    logger.info("token_verified", user_id=claims.get("user_id"))
    return claims
def revoke_token(token: str) -> bool:
    """Revoke a token (logout) by removing it from the in-memory registry.

    Args:
        token: Encoded JWT string to revoke.

    Returns:
        ``True`` when the token was registered and has now been removed,
        ``False`` when it was not in the registry.
    """
    entry = _active_tokens.pop(token, None)
    if entry is None:
        return False
    logger.info("token_revoked", user_id=entry["user_id"])
    return True
def authenticate_user(email: str, password: str) -> Optional[Dict[str, Any]]:
    """Log a user in with email and password.

    Args:
        email: User email.
        password: Plain password, checked by the user store's ``verify_password``.

    Returns:
        ``{"user": {...}, "token": str}`` when the store accepts the
        credentials, otherwise ``None``.
    """
    store = get_user_memory_store()
    account = store.verify_password(email, password)

    if not account:
        logger.warning("authentication_failed", email=email)
        return None

    admin_flag = account.get("is_admin", False)
    access_token = create_access_token(
        user_id=account["user_id"],
        email=account["email"],
        is_admin=admin_flag,
    )
    logger.info("user_authenticated", user_id=account["user_id"], email=email)

    public_profile = {
        "user_id": account["user_id"],
        "email": account["email"],
        "full_name": account.get("full_name"),
        "is_admin": account.get("is_admin", False),
    }
    return {"user": public_profile, "token": access_token}
def register_user(
    email: str,
    password: str,
    full_name: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Register a new user and issue them an access token.

    Args:
        email: User email.
        password: User password, passed through to the user store.
        full_name: User's full name, if known.

    Returns:
        ``{"user": {...}, "token": str}`` on success, or ``None`` when the
        user store rejects the account with ``ValueError`` (presumably because
        the email is already registered; confirm against the store).
    """
    memory_store = get_user_memory_store()

    # Keep the try narrow. Only the store rejecting the account means the
    # registration failed. A ValueError raised afterwards (token creation,
    # logging) must not be reported as "registration_failed" with a None
    # result, because by then the user has already been created.
    try:
        user = memory_store.create_user(
            email=email,
            password=password,
            full_name=full_name,
        )
    except ValueError as e:
        logger.warning("registration_failed", email=email, error=str(e))
        return None

    token = create_access_token(
        user_id=user["user_id"],
        email=user["email"],
        is_admin=user.get("is_admin", False),
    )
    logger.info("user_registered", user_id=user["user_id"], email=email)
    return {
        "user": {
            "user_id": user["user_id"],
            "email": user["email"],
            "full_name": user.get("full_name"),
            "is_admin": user.get("is_admin", False),
        },
        "token": token,
    }
def authenticate_oauth(
    email: str,
    oauth_provider: str,
    oauth_id: str,
    full_name: Optional[str] = None,
) -> Dict[str, Any]:
    """Sign in via an OAuth provider, creating the account on first login.

    Args:
        email: Email reported by the OAuth provider; used to look up the account.
        oauth_provider: Provider name (google, microsoft, linkedin, facebook).
        oauth_id: The user's ID at the provider; stored only for new accounts.
        full_name: Display name from the provider; stored only for new accounts.

    Returns:
        ``{"user": {...}, "token": str}`` for the existing or newly created user.
    """
    store = get_user_memory_store()
    account = store.get_user_by_email(email)

    if account:
        # NOTE(review): an existing account is matched on email alone. The
        # incoming oauth_provider/oauth_id are not compared with anything
        # stored on the account. Confirm every provider only returns verified
        # emails, otherwise an OAuth identity could be linked to someone
        # else's account.
        store.update_user(account["user_id"], {"last_login": time.time()})
        logger.info("oauth_user_login", email=email, provider=oauth_provider)
    else:
        account = store.create_user(
            email=email,
            full_name=full_name,
            oauth_provider=oauth_provider,
            oauth_id=oauth_id,
        )
        logger.info("oauth_user_created", email=email, provider=oauth_provider)

    access_token = create_access_token(
        user_id=account["user_id"],
        email=account["email"],
        is_admin=account.get("is_admin", False),
    )
    public_profile = {
        "user_id": account["user_id"],
        "email": account["email"],
        "full_name": account.get("full_name"),
        "is_admin": account.get("is_admin", False),
    }
    return {"user": public_profile, "token": access_token}
def cleanup_expired_tokens() -> int:
    """Purge registry entries whose recorded expiry time is already past.

    Returns:
        How many tokens were dropped from the in-memory registry.
    """
    now = time.time()
    # Collect first, then delete: mutating a dict while iterating it raises.
    stale = [tok for tok, meta in _active_tokens.items() if meta["expires_at"] < now]
    for tok in stale:
        _active_tokens.pop(tok)

    removed = len(stale)
    if removed:
        logger.info("expired_tokens_cleaned", count=removed)
    return removed
def get_current_user(token: str) -> Optional[Dict[str, Any]]:
    """Resolve the stored user record that a token belongs to.

    Args:
        token: Encoded JWT string.

    Returns:
        Whatever the user store returns for the token's ``user_id`` claim, or
        ``None`` when the token does not verify.
    """
    claims = verify_token(token)
    if claims:
        return get_user_memory_store().get_user_by_id(claims["user_id"])
    return None