Tech Stack Advisor - Code Viewer

← Back to File Tree

main.py

Language: python | Path: backend/src/api/main.py | Lines: 1178
"""FastAPI application for Tech Stack Advisor."""
import time
import secrets
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator
from pathlib import Path

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse, RedirectResponse, HTMLResponse, Response
from fastapi.staticfiles import StaticFiles
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from prometheus_client import Counter, Histogram, Gauge, generate_latest, REGISTRY

from ..core.config import settings
from ..core.logging import setup_logging, get_logger, usage_tracker
from ..core.session import SessionStore
from ..core.auth import authenticate_user, register_user, verify_token, revoke_token, get_current_user, authenticate_oauth
from ..core.google_oauth import get_google_auth_url, exchange_code_for_token, get_user_info, is_oauth_configured
from ..core.user_memory import get_user_memory_store
from ..orchestration import TechStackOrchestrator
from ..agents import ConversationManager
from .models import (
    RecommendationRequest,
    RecommendationResponse,
    HealthResponse,
    MetricsResponse,
    ConversationStartRequest,
    ConversationStartResponse,
    ConversationMessageRequest,
    ConversationMessageResponse,
    ConversationStatusResponse,
    RegisterRequest,
    LoginRequest,
    AuthResponse,
    FeedbackRequest,
    SimilarQueriesResponse,
    AdminStatsResponse,
)

# Setup logging
# Runs at import time so the module-level `logger` below is usable by every
# definition in this file (both helpers are imported from ..core.logging).
setup_logging()
logger = get_logger(__name__)

# Prometheus Metrics - Gracefully handle re-registration
# Constructing a metric whose name is already present in the default registry
# raises ValueError("Duplicated timeseries ..."); this can happen when the
# module is imported more than once in the same process (e.g. hot reloads).
def _get_or_create_metric(
    metric_cls: Any,
    name: str,
    documentation: str,
    labelnames: tuple[str, ...] = (),
) -> Any:
    """Create a Prometheus metric, or reuse the one already registered.

    Args:
        metric_cls: Metric class to instantiate (Counter, Histogram or Gauge).
        name: Metric name as registered with the default registry.
        documentation: Help text for the metric.
        labelnames: Label names for the metric (empty for unlabelled metrics).

    Returns:
        The newly created metric, or the collector already registered
        under ``name``.

    Raises:
        ValueError: If creation fails for a reason other than duplication, or
            if the duplicate cannot be found in the registry. Failing here, at
            import time, avoids leaving the module-level metric set to None
            and crashing later on the first request.
    """
    try:
        return metric_cls(name, documentation, labelnames=labelnames)
    except ValueError as e:
        if 'Duplicated timeseries' not in str(e):
            raise

        # NOTE(review): _names_to_collectors is a private attribute of the
        # prometheus_client registry; confirm it still exists when upgrading.
        existing = REGISTRY._names_to_collectors.get(name)
        if existing is None:
            raise

        logger.info("prometheus_metrics_reused", message="Reusing existing Prometheus metrics from registry")
        return existing


http_requests_total = _get_or_create_metric(
    Counter,
    'http_requests_total',
    'Total HTTP requests',
    ('method', 'endpoint', 'status_code'),
)

http_request_duration_seconds = _get_or_create_metric(
    Histogram,
    'http_request_duration_seconds',
    'HTTP request duration in seconds',
    ('method', 'endpoint'),
)

llm_tokens_total = _get_or_create_metric(
    Counter,
    'llm_tokens_total',
    'Total LLM tokens used',
    ('agent', 'token_type'),
)

llm_cost_usd_total = _get_or_create_metric(
    Counter,
    'llm_cost_usd_total',
    'Total LLM cost in USD',
    ('agent',),
)

llm_requests_total = _get_or_create_metric(
    Counter,
    'llm_requests_total',
    'Total LLM requests',
    ('agent', 'status'),
)

active_conversation_sessions = _get_or_create_metric(
    Gauge,
    'active_conversation_sessions',
    'Number of active conversation sessions',
)

user_registrations_total = _get_or_create_metric(
    Counter,
    'user_registrations_total',
    'Total user registrations',
    ('oauth_provider',),
)

user_logins_total = _get_or_create_metric(
    Counter,
    'user_logins_total',
    'Total user logins',
    ('oauth_provider',),
)

recommendations_total = _get_or_create_metric(
    Counter,
    'recommendations_total',
    'Total recommendations generated',
    ('status', 'authenticated'),
)

# Global state
# Process-local, in-memory state shared by the handlers in this module:
#   start_time           - epoch seconds; reset in lifespan() at startup and
#                          read by /health to compute uptime
#   total_requests       - incremented by the log_requests middleware
#   orchestrator         - TechStackOrchestrator instance, set in lifespan()
#   conversation_manager - ConversationManager instance, set in lifespan()
app_state: dict[str, Any] = {
    "start_time": time.time(),
    "total_requests": 0,
    "orchestrator": None,
    "conversation_manager": None,
}

# OAuth state storage (for CSRF protection)
# Maps each issued `state` token to the time.time() value at which it was
# issued (see google_oauth_login / google_oauth_callback).
# In production, use Redis or similar
oauth_states: dict[str, float] = {}


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Populate shared application state on startup and tidy up on shutdown.

    Before yielding, the orchestrator, the conversation manager and a fresh
    start timestamp are stored in ``app_state``. After the application stops,
    expired sessions are removed from the session store.

    Args:
        app: The FastAPI application this lifespan is attached to.
    """
    logger.info("api_startup", environment=settings.environment)

    # Populate the shared state in a fixed order: orchestrator first, then the
    # conversation manager, then the startup timestamp.
    startup_factories = (
        ("orchestrator", TechStackOrchestrator),
        ("conversation_manager", ConversationManager),
        ("start_time", time.time),
    )
    for state_key, factory in startup_factories:
        app_state[state_key] = factory()

    logger.info("orchestrator_loaded", agents=4)
    logger.info("conversation_manager_loaded")

    yield

    # Everything below runs once the application is shutting down.
    logger.info("api_shutdown")
    SessionStore.cleanup_expired_sessions()


# Initialize FastAPI app
# `lifespan` (defined above) fills app_state on startup and cleans up
# expired sessions on shutdown.
app = FastAPI(
    title="Tech Stack Advisor API",
    description="Multi-agent RAG system for intelligent tech stack recommendations",
    version="0.1.0",
    lifespan=lifespan,
)

# Rate limiter
# Requests are keyed by client address (get_remote_address). slowapi looks the
# limiter up on app.state, and RateLimitExceeded is routed to slowapi's
# stock handler.
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# CORS middleware
# NOTE(review): allow_origins=["*"] together with allow_credentials=True lets
# any origin make credentialed requests; restrict the origin list before
# deploying to production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.middleware("http")
async def log_requests(request: Request, call_next: Any) -> Any:
    """Log every request and record Prometheus metrics for it.

    Args:
        request: Incoming HTTP request.
        call_next: Callable that passes the request on to the rest of the
            application and returns its response.

    Returns:
        The response produced by the downstream application, unchanged.
    """
    start_time = time.time()
    response = await call_next(request)
    duration = time.time() - start_time

    # Label metrics with the matched route template (e.g.
    # "/conversation/{session_id}") instead of the raw URL, so that path
    # parameters do not create one time series per distinct value. FastAPI's
    # router stores the matched route in the ASGI scope; fall back to the raw
    # path when no route matched (e.g. 404s or mounted apps).
    matched_route = request.scope.get("route")
    endpoint_label = getattr(matched_route, "path", None) or request.url.path

    # Track Prometheus metrics
    http_requests_total.labels(
        method=request.method,
        endpoint=endpoint_label,
        status_code=response.status_code
    ).inc()

    http_request_duration_seconds.labels(
        method=request.method,
        endpoint=endpoint_label
    ).observe(duration)

    # Update active sessions gauge (SessionStore is imported at module level)
    active_conversation_sessions.set(SessionStore.get_active_session_count())

    # Structured logging keeps the concrete path for debugging
    logger.info(
        "http_request",
        method=request.method,
        path=request.url.path,
        status_code=response.status_code,
        duration_ms=round(duration * 1000, 2),
    )

    app_state["total_requests"] += 1
    return response


@app.get("/", tags=["root"])
async def root():
    """Return the web UI entry page (``static/index.html``).

    Raises:
        HTTPException: 500 when index.html is missing from the static directory.
    """
    index_path = Path(__file__).parent.parent.parent / "static" / "index.html"

    if index_path.exists():
        return FileResponse(index_path)

    # The UI entry page is missing; log where it was expected and fail.
    logger.error(f"index.html not found at {index_path}")
    raise HTTPException(status_code=500, detail="Frontend not found")


@app.get("/login.html", tags=["root"])
async def login_page():
    """Return ``static/login.html``.

    Raises:
        HTTPException: 404 when the file does not exist.
    """
    login_path = Path(__file__).parent.parent.parent / "static" / "login.html"

    if login_path.exists():
        return FileResponse(login_path)

    raise HTTPException(status_code=404, detail="Login page not found")


@app.get("/register.html", tags=["root"])
async def register_page():
    """Return ``static/register.html``.

    Raises:
        HTTPException: 404 when the file does not exist.
    """
    register_path = Path(__file__).parent.parent.parent / "static" / "register.html"

    if register_path.exists():
        return FileResponse(register_path)

    raise HTTPException(status_code=404, detail="Registration page not found")


@app.get("/auth.js", tags=["root"])
async def auth_js():
    """Return ``static/auth.js`` with a JavaScript media type.

    Raises:
        HTTPException: 404 when the file does not exist.
    """
    auth_js_path = Path(__file__).parent.parent.parent / "static" / "auth.js"

    if auth_js_path.exists():
        return FileResponse(auth_js_path, media_type="text/javascript")

    raise HTTPException(status_code=404, detail="auth.js not found")


@app.get("/test-auth.html", tags=["root"])
async def test_auth_page():
    """Return ``static/test-auth.html``.

    Raises:
        HTTPException: 404 when the file does not exist.
    """
    test_path = Path(__file__).parent.parent.parent / "static" / "test-auth.html"

    if test_path.exists():
        return FileResponse(test_path)

    raise HTTPException(status_code=404, detail="test-auth.html not found")


@app.get("/admin.html", tags=["root"])
async def admin_page():
    """Return ``static/admin.html``.

    Raises:
        HTTPException: 404 when the file does not exist.
    """
    admin_path = Path(__file__).parent.parent.parent / "static" / "admin.html"

    if admin_path.exists():
        return FileResponse(admin_path)

    raise HTTPException(status_code=404, detail="admin.html not found")


@app.get("/api", tags=["root"])
async def api_info() -> dict[str, str]:
    """Describe the API: its name, version and where to find docs and health.

    Returns:
        A small static mapping of string keys to string values.
    """
    return dict(
        message="Tech Stack Advisor API",
        version="0.1.0",
        docs="/docs",
        health="/health",
    )


@app.get("/health", response_model=HealthResponse, tags=["monitoring"])
async def health_check() -> HealthResponse:
    """Report service health together with uptime since startup.

    Returns:
        A HealthResponse with a fixed "healthy" status, the API version, the
        number of agents and the uptime in seconds rounded to two decimals.
    """
    seconds_up = round(time.time() - app_state["start_time"], 2)

    return HealthResponse(
        status="healthy",
        version="0.1.0",
        agents_loaded=4,
        uptime_seconds=seconds_up,
    )


@app.get("/metrics", response_model=MetricsResponse, tags=["monitoring"])
async def get_metrics() -> MetricsResponse:
    """Summarise request counts and today's usage figures.

    Returns:
        A MetricsResponse built from the in-process request counter and the
        usage tracker. The remaining budget never goes below zero.
    """
    unspent = settings.daily_budget_usd - usage_tracker.daily_cost
    budget_remaining = unspent if unspent > 0 else 0

    # Both cost fields are filled from the same daily figure.
    summary = {
        "total_requests": app_state["total_requests"],
        "total_tokens": usage_tracker.daily_tokens,
        "total_cost_usd": round(usage_tracker.daily_cost, 4),
        "daily_queries": usage_tracker.daily_queries,
        "daily_cost_usd": round(usage_tracker.daily_cost, 4),
        "budget_remaining_usd": round(budget_remaining, 4),
    }
    return MetricsResponse(**summary)


@app.get("/metrics/prometheus", tags=["monitoring"])
async def prometheus_metrics():
    """Expose Prometheus metrics for Grafana Cloud scraping.

    This endpoint provides metrics in Prometheus format that can be scraped by
    Grafana Cloud or any Prometheus-compatible monitoring system.

    Metrics exposed:
    - http_requests_total: Total HTTP requests by method, endpoint, and status
    - http_request_duration_seconds: Request duration histogram
    - llm_tokens_total: LLM token usage by agent and type (input/output)
    - llm_cost_usd_total: LLM cost in USD by agent
    - llm_requests_total: Total LLM requests by agent and status
    - active_conversation_sessions: Current active conversation sessions
    - user_registrations_total: Total user registrations by OAuth provider
    - user_logins_total: Total user logins by OAuth provider
    - recommendations_total: Total recommendations by status and auth state

    Returns:
        Prometheus-formatted metrics text, served with the exposition-format
        content type published by prometheus_client.
    """
    # Imported locally so this handler does not depend on changes to the
    # module-level import list.
    from prometheus_client import CONTENT_TYPE_LATEST

    # CONTENT_TYPE_LATEST carries the exposition format version and charset,
    # which a bare "text/plain" does not advertise to scrapers.
    return Response(generate_latest(REGISTRY), media_type=CONTENT_TYPE_LATEST)


@app.post(
    "/recommend",
    response_model=RecommendationResponse,
    tags=["recommendations"],
)
@limiter.limit(settings.rate_limit_demo)
async def get_recommendation(
    request: Request,
    req: RecommendationRequest,
) -> RecommendationResponse:
    """Get tech stack recommendations based on user query.

    This endpoint orchestrates 4 specialized AI agents:
    1. Database Agent - Recommends databases and scaling strategies
    2. Infrastructure Agent - Suggests cloud architecture and deployment
    3. Cost Agent - Provides cost estimates across providers
    4. Security Agent - Performs threat modeling and compliance checks

    Rate Limits:
    - Demo mode (no API key): 5 requests/hour per IP
    - Authenticated mode: 50 requests/hour
    NOTE(review): the only limit applied in this file is
    ``settings.rate_limit_demo``; the figures above cannot be confirmed here.

    Flow:
    1. Optionally resolve the caller from a Bearer token.
    2. For anonymous callers without their own API key, enforce the daily
       budget and the daily query cap (HTTP 429 when exceeded).
    3. Run the orchestrator and wrap its result in a RecommendationResponse.
    4. Record the outcome in Prometheus and, for authenticated callers,
       store the query in the user memory store (best effort).

    Args:
        request: FastAPI request object
        req: Recommendation request with user query

    Returns:
        Comprehensive tech stack recommendation from all agents. When the
        orchestrator reports a non-success status, a response with
        status="error" is returned instead of raising.

    Raises:
        HTTPException: 429 if the daily budget or query cap is exceeded,
            500 if processing raises an unexpected exception
    """
    # Check for optional authentication (long-term memory)
    token = get_token_from_header(request)
    current_user = None
    if token:
        current_user = get_current_user(token)
        if current_user:
            logger.info("recommendation_request_authenticated", user_id=current_user["user_id"], query=req.query[:100])
        else:
            # An invalid token is not an error here: the request simply
            # continues as an anonymous one.
            logger.warning("invalid_token_ignored")
    else:
        logger.info("recommendation_request", query=req.query[:100], using_custom_key=req.api_key is not None)

    # Only check budget/query limits if NOT authenticated and NOT using custom API key
    if not current_user and not req.api_key:
        # Check daily budget
        if usage_tracker.daily_cost >= settings.daily_budget_usd:
            logger.warning("daily_budget_exceeded", cost=usage_tracker.daily_cost)
            raise HTTPException(
                status_code=429,
                detail=f"Daily budget of ${settings.daily_budget_usd} exceeded. "
                f"Current cost: ${usage_tracker.daily_cost:.2f}",
            )

        # Check daily query cap
        if usage_tracker.daily_queries >= settings.daily_query_cap:
            logger.warning("daily_query_cap_exceeded", queries=usage_tracker.daily_queries)
            raise HTTPException(
                status_code=429,
                detail=f"Daily query cap of {settings.daily_query_cap} reached. "
                "Try again tomorrow or register for unlimited queries.",
            )

    try:
        orchestrator: TechStackOrchestrator = app_state["orchestrator"]

        # Track tokens before query
        # NOTE(review): usage_tracker is shared module state, so the delta
        # computed below presumably also counts tokens consumed by other
        # requests that run concurrently - confirm.
        tokens_before = usage_tracker.daily_tokens

        # Process the query with optional DAU override and API key
        result = await orchestrator.process_query(req.query, dau_override=req.dau, api_key=req.api_key)

        # Calculate tokens and cost for this query
        tokens_used = usage_tracker.daily_tokens - tokens_before
        # Rough cost estimate: $0.015 per 1M input tokens, $0.075 per 1M output tokens (Haiku)
        # NOTE(review): the code applies a flat $0.03 per 1M tokens, which is
        # not the arithmetic mean of the two prices quoted above - confirm
        # the intended blended rate.
        estimated_cost = (tokens_used / 1_000_000) * 0.03  # Average

        # Convert to response model
        if result.get("status") == "success":
            response = RecommendationResponse(
                status=result["status"],
                query=result["query"],
                correlation_id=result["correlation_id"],
                parsed_context=result.get("parsed_context"),
                recommendations=result.get("recommendations"),
            )
        else:
            # Orchestrator-level failures are reported in the response body
            # rather than raised.
            response = RecommendationResponse(
                status="error",
                query=req.query,
                correlation_id=result.get("correlation_id", "unknown"),
                error=result.get("error", "Unknown error"),
            )

        # Track Prometheus metrics
        # The `authenticated` label value is the string "True" or "False".
        recommendations_total.labels(
            status=response.status,
            authenticated=str(current_user is not None)
        ).inc()

        # Store in long-term memory if authenticated
        if current_user:
            try:
                memory_store = get_user_memory_store()
                query_id = memory_store.store_query(
                    user_id=current_user["user_id"],
                    query=req.query,
                    correlation_id=response.correlation_id,
                    dau=req.dau,
                    parsed_context=result.get("parsed_context"),
                    recommendations=result.get("recommendations"),
                    status=response.status,
                    error_message=response.error,
                    tokens_used=tokens_used,
                    cost_usd=estimated_cost,
                )
                logger.info(
                    "query_stored_in_memory",
                    user_id=current_user["user_id"],
                    query_id=query_id,
                )
            except Exception as e:
                logger.error("failed_to_store_query", error=str(e))
                # Don't fail the request if memory storage fails

        logger.info(
            "recommendation_complete",
            correlation_id=response.correlation_id,
            status=response.status,
            authenticated=current_user is not None,
        )

        return response

    except Exception as e:
        # NOTE(review): the exception text is returned to the client in the
        # 500 detail; consider whether internal messages should be exposed.
        logger.error("recommendation_error", error=str(e))
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error: {str(e)}",
        )


@app.post("/generate-diagram", tags=["recommendations"])
@limiter.limit(settings.rate_limit_demo)
async def generate_architecture_diagram(
    request: Request,
    req: dict,
) -> JSONResponse:
    """Ask the infrastructure agent for an architecture diagram.

    Args:
        request: FastAPI request object (needed by the rate limiter)
        req: Free-form JSON body; the keys read are ``user_query``,
            ``recommendations``, ``scale_tier`` and ``api_key``

    Returns:
        The agent's result wrapped in a JSONResponse

    Raises:
        HTTPException: 500 if diagram generation raises
    """
    logger.info("diagram_request", has_custom_key=req.get("api_key") is not None)

    try:
        orchestrator: TechStackOrchestrator = app_state["orchestrator"]
        agent = orchestrator.infrastructure_agent

        # Assemble the agent input, falling back to defaults for absent keys.
        diagram_input = {
            "user_query": req.get("user_query", ""),
            "recommendations": req.get("recommendations", ""),
            "scale_tier": req.get("scale_tier", "STARTER"),
            "api_key": req.get("api_key"),
        }
        diagram = await agent.generate_diagram(diagram_input)

        logger.info("diagram_generated")
        return JSONResponse(content=diagram)

    except Exception as e:
        logger.error("diagram_error", error=str(e))
        raise HTTPException(
            status_code=500,
            detail=f"Diagram generation error: {str(e)}",
        )


# Conversation API endpoints
@app.post(
    "/conversation/start",
    response_model=ConversationStartResponse,
    tags=["conversation"],
)
@limiter.limit(settings.rate_limit_demo)
async def start_conversation(
    request: Request,
    req: ConversationStartRequest,
) -> ConversationStartResponse:
    """Open a new conversation session from the user's first message.

    The conversation manager processes the initial message; a session is then
    created and seeded with the user's message, the manager's first question
    and the context extracted so far.

    Args:
        request: FastAPI request object (needed by the rate limiter)
        req: Request carrying the initial message and an optional API key

    Returns:
        Session ID, first question, and extracted context

    Raises:
        HTTPException: 500 if any step fails
    """
    logger.info("conversation_start_request", message=req.initial_message[:100])

    try:
        manager: ConversationManager = app_state["conversation_manager"]
        opening = await manager.start_conversation(
            initial_input=req.initial_message,
            api_key=req.api_key
        )

        # The session is only created once the manager has answered.
        session_id = SessionStore.create_session()

        # Seed the session: both opening turns first, then the extracted state.
        SessionStore.add_message(session_id, "user", req.initial_message)
        SessionStore.add_message(session_id, "assistant", opening["next_question"])
        seed_state = {
            "extracted_context": opening["extracted_context"],
            "completion_percentage": opening["completion_percentage"],
        }
        SessionStore.update_session(session_id, seed_state)

        logger.info(
            "conversation_started",
            session_id=session_id,
            completion=opening["completion_percentage"]
        )

        return ConversationStartResponse(
            session_id=session_id,
            question=opening["next_question"],
            question_type=opening.get("question_type"),
            options=opening.get("options"),
            extracted_context=opening["extracted_context"],
            completion_percentage=opening["completion_percentage"],
        )

    except Exception as e:
        logger.error("conversation_start_error", error=str(e))
        raise HTTPException(
            status_code=500,
            detail=f"Failed to start conversation: {str(e)}",
        )


@app.post(
    "/conversation/message",
    response_model=ConversationMessageResponse,
    tags=["conversation"],
)
@limiter.limit(settings.rate_limit_demo)
async def send_conversation_message(
    request: Request,
    req: ConversationMessageRequest,
) -> ConversationMessageResponse:
    """Advance an existing conversation by one user message.

    The stored history and context are handed to the conversation manager
    together with the new message. The session is then updated with the new
    turn(s) and the manager's latest context, progress and readiness flag.

    Args:
        request: FastAPI request object (needed by the rate limiter)
        req: User message, session ID and optional API key

    Returns:
        Next question (or ready flag), updated context, and conversation history

    Raises:
        HTTPException: 404 if the session is unknown or expired,
            500 if processing the message fails
    """
    logger.info("conversation_message_request", session_id=req.session_id)

    session = SessionStore.get_session(req.session_id)
    if not session:
        raise HTTPException(
            status_code=404,
            detail="Session not found or expired"
        )

    try:
        manager: ConversationManager = app_state["conversation_manager"]
        turn = await manager.continue_conversation(
            conversation_history=session["conversation_history"],
            user_response=req.message,
            current_context=session["extracted_context"],
            api_key=req.api_key
        )

        # Record the user's turn first.
        SessionStore.add_message(req.session_id, "user", req.message)

        # A follow-up question is only recorded while more input is needed.
        still_gathering = not turn["ready_for_recommendation"]
        if still_gathering and turn.get("next_question"):
            SessionStore.add_message(req.session_id, "assistant", turn["next_question"])

        new_state = {
            "extracted_context": turn["extracted_context"],
            "completion_percentage": turn["completion_percentage"],
            "ready_for_recommendation": turn["ready_for_recommendation"],
        }
        SessionStore.update_session(req.session_id, new_state)

        # Re-read the session so the response carries the updated history.
        session = SessionStore.get_session(req.session_id)

        logger.info(
            "conversation_message_processed",
            session_id=req.session_id,
            completion=turn["completion_percentage"],
            ready=turn["ready_for_recommendation"]
        )

        return ConversationMessageResponse(
            session_id=req.session_id,
            question=turn.get("next_question"),
            question_type=turn.get("question_type"),
            options=turn.get("options"),
            extracted_context=turn["extracted_context"],
            completion_percentage=turn["completion_percentage"],
            ready_for_recommendation=turn["ready_for_recommendation"],
            conversation_history=session["conversation_history"],
        )

    except Exception as e:
        logger.error("conversation_message_error", error=str(e), session_id=req.session_id)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to process message: {str(e)}",
        )


@app.get(
    "/conversation/{session_id}",
    response_model=ConversationStatusResponse,
    tags=["conversation"],
)
async def get_conversation_status(session_id: str) -> ConversationStatusResponse:
    """Look up a conversation session and report its current state.

    Args:
        session_id: Session ID

    Returns:
        A status response with ``exists=True`` and the stored history,
        context, completion percentage and readiness flag. For an unknown
        session, ``exists=False`` with all of those fields set to None.
    """
    logger.info("conversation_status_request", session_id=session_id)

    session = SessionStore.get_session(session_id)
    state_fields = (
        "conversation_history",
        "extracted_context",
        "completion_percentage",
        "ready_for_recommendation",
    )

    if session:
        details = {field: session[field] for field in state_fields}
    else:
        # Unknown session: every state field is reported as None.
        details = dict.fromkeys(state_fields)

    return ConversationStatusResponse(
        session_id=session_id,
        exists=bool(session),
        **details,
    )


# Helper function for authentication
def get_token_from_header(request: Request) -> str | None:
    """Extract the Bearer token from the request's Authorization header.

    The scheme name is matched case-insensitively (HTTP authentication
    schemes are case-insensitive), and surrounding whitespace is stripped
    from the token.

    Args:
        request: Object exposing the request headers via ``request.headers``.

    Returns:
        The token string, or None when the header is missing, uses another
        scheme, or carries no token after the scheme name.
    """
    auth_header = request.headers.get("Authorization")
    if not auth_header:
        return None

    scheme, _, credentials = auth_header.partition(" ")
    if scheme.lower() != "bearer":
        return None

    # An empty credential ("Bearer " with nothing after it) is not a token.
    return credentials.strip() or None


# Authentication endpoints
@app.post("/auth/register", response_model=AuthResponse, tags=["authentication"])
async def register(req: RegisterRequest) -> AuthResponse:
    """Create a user account from an email, password and full name.

    Args:
        req: Registration request with email, password, full_name

    Returns:
        Authentication response built from the registration result

    Raises:
        HTTPException: 400 when registration yields no result, which is
            reported to the caller as an already-existing email
    """
    logger.info("registration_attempt", email=req.email)

    auth_payload = register_user(
        email=req.email,
        password=req.password,
        full_name=req.full_name,
    )

    if auth_payload:
        logger.info("user_registered", email=req.email)
        return AuthResponse(**auth_payload)

    raise HTTPException(
        status_code=400,
        detail="User with this email already exists"
    )


@app.post("/auth/login", response_model=AuthResponse, tags=["authentication"])
async def login(req: LoginRequest) -> AuthResponse:
    """Authenticate a user by email and password.

    Args:
        req: Login request with email and password

    Returns:
        Authentication response built from the authentication result

    Raises:
        HTTPException: 401 when authentication yields no result
    """
    logger.info("login_attempt", email=req.email)

    auth_payload = authenticate_user(email=req.email, password=req.password)

    if auth_payload:
        logger.info("user_logged_in", email=req.email)
        return AuthResponse(**auth_payload)

    raise HTTPException(
        status_code=401,
        detail="Invalid email or password"
    )


@app.post("/auth/logout", tags=["authentication"])
async def logout(request: Request) -> dict[str, str]:
    """Revoke the caller's Bearer token, if one was supplied.

    The endpoint answers with the same success message whether or not a
    token was present.

    Returns:
        Success message
    """
    bearer_token = get_token_from_header(request)

    if bearer_token:
        revoke_token(bearer_token)
        logger.info("user_logged_out")

    return dict(message="Logged out successfully")


@app.get("/auth/me", tags=["authentication"])
async def get_current_user_info(request: Request) -> dict[str, Any]:
    """Return the profile of the user identified by the Bearer token.

    Requires:
        Authorization: Bearer <token>

    Returns:
        Mapping with user_id, email, full_name, is_admin, total_queries and
        total_cost_usd. The last four fall back to defaults when absent.

    Raises:
        HTTPException: 401 if no token was sent or the token is not accepted
    """
    token = get_token_from_header(request)
    if not token:
        raise HTTPException(
            status_code=401,
            detail="Not authenticated. Please provide a valid Bearer token."
        )

    user = get_current_user(token)
    if not user:
        raise HTTPException(
            status_code=401,
            detail="Invalid or expired token"
        )

    # Required fields are read directly; optional ones get a fallback value.
    profile: dict[str, Any] = {
        "user_id": user["user_id"],
        "email": user["email"],
    }
    optional_fields = (
        ("full_name", None),
        ("is_admin", False),
        ("total_queries", 0),
        ("total_cost_usd", 0.0),
    )
    for field_name, fallback in optional_fields:
        profile[field_name] = user.get(field_name, fallback)

    return profile


# Google OAuth endpoints
@app.get("/auth/google/login", tags=["authentication"])
async def google_oauth_login() -> RedirectResponse:
    """Start the Google OAuth flow.

    A random ``state`` value is generated and remembered for the callback's
    CSRF check, states older than ten minutes are purged, and the caller is
    redirected to the Google authorization URL.

    Raises:
        HTTPException: 501 if Google OAuth is not configured
    """
    if not is_oauth_configured():
        raise HTTPException(
            status_code=501,
            detail="Google OAuth is not configured. Please set GOOGLE_CLIENT_ID and GOOGLE_CLIENT_SECRET."
        )

    # Remember a fresh random state together with its issue time.
    state = secrets.token_urlsafe(32)
    oauth_states[state] = time.time()

    # Purge states issued more than 600 seconds ago.
    now = time.time()
    stale_states = [token for token, issued_at in oauth_states.items() if now - issued_at > 600]
    for stale in stale_states:
        del oauth_states[stale]

    auth_url = get_google_auth_url(state)

    logger.info("google_oauth_login_initiated", state=state)
    return RedirectResponse(url=auth_url)


@app.get("/auth/google/callback", tags=["authentication"])
async def google_oauth_callback(code: str, state: str):
    """Handle Google OAuth callback.

    This endpoint is called by Google after user authorizes the app.
    Returns HTML that saves the token and redirects to the home page.

    Args:
        code: Authorization code from Google
        state: State parameter for CSRF protection

    Returns:
        HTML response that handles token storage and redirect

    Raises:
        HTTPException: 501 if OAuth is not configured; 400 if the state is
            unknown or expired, the code exchange fails, or user information
            cannot be fetched; 401 if the OAuth sign-in yields no token
    """
    # Imported locally so this handler does not depend on changes to the
    # module-level import list.
    import json

    # Matches the ten-minute window used when purging states at login time.
    state_ttl_seconds = 600

    if not is_oauth_configured():
        raise HTTPException(
            status_code=501,
            detail="Google OAuth is not configured"
        )

    # Verify state (CSRF protection). pop() makes the state single-use.
    issued_at = oauth_states.pop(state, None)
    if issued_at is None:
        logger.warning("google_oauth_invalid_state", state=state)
        raise HTTPException(
            status_code=400,
            detail="Invalid state parameter. Please try logging in again."
        )

    # Expired states are only purged when someone starts a new login, so the
    # age must also be checked here or an old state would remain valid.
    if time.time() - issued_at > state_ttl_seconds:
        logger.warning("google_oauth_expired_state", state=state)
        raise HTTPException(
            status_code=400,
            detail="Invalid state parameter. Please try logging in again."
        )

    # Exchange code for access token
    access_token = exchange_code_for_token(code)
    if not access_token:
        logger.error("google_oauth_token_exchange_failed")
        raise HTTPException(
            status_code=400,
            detail="Failed to exchange authorization code for access token"
        )

    # Get user info from Google
    user_info = get_user_info(access_token)
    if not user_info or not user_info.get("email"):
        logger.error("google_oauth_user_info_failed")
        raise HTTPException(
            status_code=400,
            detail="Failed to get user information from Google"
        )

    # Authenticate or create user
    auth_data = authenticate_oauth(
        email=user_info["email"],
        oauth_provider="google",
        oauth_id=user_info["google_id"],
        full_name=user_info.get("name"),
    )

    # Without this guard a missing result would surface as an unhandled
    # TypeError/KeyError (HTTP 500) when the token is read below.
    if not auth_data or "token" not in auth_data:
        logger.error("google_oauth_authentication_failed", email=user_info["email"])
        raise HTTPException(
            status_code=401,
            detail="Failed to authenticate with Google"
        )

    logger.info("google_oauth_success", email=user_info["email"])

    # Encode the token as a JavaScript string literal instead of pasting it
    # between quotes; "<" is escaped so the value cannot close the <script>
    # element.
    token_js = json.dumps(str(auth_data["token"])).replace("<", "\\u003c")

    # Return HTML that saves the token and redirects
    html_content = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <title>Login Successful</title>
        <style>
            body {{
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
                display: flex;
                align-items: center;
                justify-content: center;
                min-height: 100vh;
                margin: 0;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            }}
            .message {{
                background: white;
                padding: 40px;
                border-radius: 20px;
                box-shadow: 0 20px 60px rgba(0,0,0,0.3);
                text-align: center;
            }}
            h1 {{
                color: #667eea;
                margin-bottom: 10px;
            }}
            p {{
                color: #666;
            }}
            .spinner {{
                border: 3px solid #f3f3f3;
                border-top: 3px solid #667eea;
                border-radius: 50%;
                width: 40px;
                height: 40px;
                animation: spin 1s linear infinite;
                margin: 20px auto;
            }}
            @keyframes spin {{
                0% {{ transform: rotate(0deg); }}
                100% {{ transform: rotate(360deg); }}
            }}
        </style>
    </head>
    <body>
        <div class="message">
            <h1>Login Successful!</h1>
            <div class="spinner"></div>
            <p>Redirecting you to the app...</p>
        </div>
        <script>
            // Save token to localStorage (use same key as auth.js)
            localStorage.setItem('tech_stack_advisor_token', {token_js});

            // Redirect to home page
            setTimeout(() => {{
                window.location.href = '/';
            }}, 1000);
        </script>
    </body>
    </html>
    """

    return HTMLResponse(content=html_content)


# Long-term memory endpoints
@app.get("/memory/similar", response_model=SimilarQueriesResponse, tags=["memory"])
async def get_similar_queries(
    request: Request,
    query: str,
    limit: int = 5,
) -> SimilarQueriesResponse:
    """Look up past queries of the authenticated user that resemble ``query``.

    The token extracted from the request is resolved to a user, and that
    user's ID scopes the search performed by the user memory store.

    Args:
        request: Incoming request, handed to ``get_token_from_header``.
        query: Text to compare against the user's stored queries.
        limit: Maximum number of matches requested from the memory store.

    Requires:
        Authorization: Bearer <token>

    Returns:
        Response wrapping the matches returned by the memory store.

    Raises:
        HTTPException: 401 if no token is present or the token does not
            resolve to a user.
    """
    bearer = get_token_from_header(request)
    if not bearer:
        raise HTTPException(status_code=401, detail="Authentication required")

    current_user = get_current_user(bearer)
    if not current_user:
        raise HTTPException(status_code=401, detail="Invalid token")

    store = get_user_memory_store()
    owner_id = current_user["user_id"]
    matches = store.search_similar_queries(user_id=owner_id, query=query, limit=limit)

    logger.info("similar_queries_searched", user_id=owner_id, found=len(matches))

    return SimilarQueriesResponse(queries=matches)


@app.post("/feedback", tags=["feedback"])
async def submit_feedback(
    request: Request,
    req: FeedbackRequest,
) -> dict[str, str]:
    """Store the authenticated user's feedback about a recommendation.

    Args:
        request: Incoming request, handed to ``get_token_from_header``.
        req: Feedback payload; its ``query_id``, ``rating``, ``comment`` and
            ``feedback_type`` fields are forwarded to the memory store.

    Requires:
        Authorization: Bearer <token>

    Returns:
        Dict with a confirmation ``message`` and the ``feedback_id`` produced
        by the memory store.

    Raises:
        HTTPException: 401 if no token is present or the token does not
            resolve to a user.
    """
    bearer = get_token_from_header(request)
    if not bearer:
        raise HTTPException(status_code=401, detail="Authentication required")

    author = get_current_user(bearer)
    if not author:
        raise HTTPException(status_code=401, detail="Invalid token")

    store = get_user_memory_store()
    # Feedback is always attributed to the token's user, never to a
    # client-supplied ID.
    payload = {
        "user_id": author["user_id"],
        "query_id": req.query_id,
        "rating": req.rating,
        "comment": req.comment,
        "feedback_type": req.feedback_type,
    }
    new_feedback_id = store.store_feedback(**payload)

    logger.info(
        "feedback_submitted",
        user_id=author["user_id"],
        feedback_id=new_feedback_id,
    )

    return {
        "message": "Feedback submitted successfully",
        "feedback_id": new_feedback_id,
    }


# Admin dashboard endpoints
@app.get("/admin/stats", response_model=AdminStatsResponse, tags=["admin"])
async def get_admin_stats(request: Request) -> AdminStatsResponse:
    """Get admin dashboard statistics.

    Shows:
    - Total users and queries
    - Per-user usage and costs
    - Recent feedback

    Args:
        request: Incoming request, handed to ``get_token_from_header``.

    Requires:
        Authorization: Bearer <token> (admin only)

    Returns:
        Admin statistics

    Raises:
        HTTPException: 401 if no token is present or the token does not
            resolve to a user; 403 if the user is not flagged ``is_admin``.
    """
    token = get_token_from_header(request)
    if not token:
        raise HTTPException(status_code=401, detail="Authentication required")

    user = get_current_user(token)
    if not user:
        # An unresolvable token is an authentication failure (401), matching
        # the other token-protected endpoints in this module; 403 is reserved
        # for authenticated users who lack the admin flag.
        raise HTTPException(status_code=401, detail="Invalid token")
    if not user.get("is_admin"):
        raise HTTPException(status_code=403, detail="Admin access required")

    memory_store = get_user_memory_store()

    # NOTE(review): the totals below are derived from this fetch. If the store
    # treats ``limit`` as a hard cap, total_users / total_queries /
    # total_cost_usd saturate at the first 100 users -- confirm.
    users = memory_store.get_all_users(limit=100)

    # Get all queries
    queries = memory_store.get_all_queries(limit=200)

    # Get all feedback
    feedback = memory_store.get_all_feedback(limit=200)

    # Aggregate usage across the fetched users; missing counters count as zero.
    total_cost = sum(u.get("total_cost_usd", 0.0) for u in users)
    total_queries_count = sum(u.get("total_queries", 0) for u in users)

    # Per-user stats
    user_stats = [
        {
            "user_id": u["user_id"],
            "email": u["email"],
            "total_queries": u.get("total_queries", 0),
            "total_cost_usd": u.get("total_cost_usd", 0.0),
            "created_at": u.get("created_at"),
            "last_login": u.get("last_login"),
        }
        for u in users
    ]

    logger.info("admin_stats_accessed", admin_user=user["email"])

    return AdminStatsResponse(
        total_users=len(users),
        total_queries=total_queries_count,
        total_cost_usd=total_cost,
        # First 50 entries of each listing; presumably the store returns
        # newest-first -- verify against the memory store implementation.
        recent_queries=queries[:50],
        recent_feedback=feedback[:50],
        user_stats=user_stats,
    )


@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Log an otherwise-unhandled exception and answer with a generic 500.

    Args:
        request: Request being processed when the exception escaped.
        exc: The unhandled exception.

    Returns:
        JSON 500 response. The exception text is included in ``detail`` only
        when ``settings.environment`` is ``"development"``; otherwise
        ``detail`` is ``None``.
    """
    logger.error(
        "unhandled_exception",
        error=str(exc),
        # str(exc) alone drops the exception class (e.g. KeyError('x') renders
        # as "'x'"), so record the type name as a separate field.
        error_type=type(exc).__name__,
        path=request.url.path,
        method=request.method,
    )

    return JSONResponse(
        status_code=500,
        content={
            "status": "error",
            "error": "An unexpected error occurred",
            "detail": str(exc) if settings.environment == "development" else None,
        },
    )


if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn using the
    # host, port and log level taken from settings.
    import uvicorn

    server_options = {
        "host": settings.api_host,
        "port": settings.api_port,
        "reload": False,  # Disable reload in production to avoid import issues
        "log_level": settings.log_level.lower(),
    }
    uvicorn.run(app, **server_options)