Tech Stack Advisor - Code Viewer

← Back to File Tree

workflow.py

Language: python | Path: backend/src/orchestration/workflow.py | Lines: 395
"""LangGraph workflow for orchestrating tech stack recommendation agents."""
import re
import uuid
from typing import Any
from langgraph.graph import StateGraph, END
from ..agents import DatabaseAgent, InfrastructureAgent, CostAgent, SecurityAgent
from ..core.logging import get_logger
from .state import WorkflowState

logger = get_logger(__name__)


class TechStackOrchestrator:
    """Orchestrates multiple agents using LangGraph to provide tech stack recommendations.

    The compiled graph is a fixed sequential pipeline::

        parse_query -> database_agent -> infrastructure_agent
                    -> cost_agent -> security_agent -> synthesize

    Each agent node stores its output under its own ``*_result`` state key.
    If an agent raises, the failure is recorded in ``state["error"]`` and the
    remaining agents are skipped, because ``synthesize`` discards all agent
    output once an error is present and reports only the error.
    """

    # ------------------------------------------------------------------
    # Query-parsing patterns. All are applied to the lower-cased query.
    # ------------------------------------------------------------------

    # A number (commas/decimals allowed) with an optional multiplier, followed
    # by "dau" or "daily active users", e.g. "50,000 dau", "50k dau",
    # "1.5 million daily active users". The trailing \b stops "dau" from
    # matching inside words such as "daughters".
    _DAU_PATTERN = re.compile(
        r"(\d+(?:,\d+)*(?:\.\d+)?)\s*(thousand|million|k|m)?\s*"
        r"(?:daily\s+active\s+users|dau)\b"
    )
    _DAU_MULTIPLIERS = {"thousand": 1_000, "k": 1_000, "million": 1_000_000, "m": 1_000_000}

    # Budget amounts must look like money: a "$" prefix, a per-month suffix,
    # or an explicit "budget of/is/:" lead-in. A bare number (e.g. a user
    # count) is deliberately NOT treated as a budget. Group 1 is the amount,
    # group 2 an optional "k" (x1,000) suffix. Patterns are tried in order.
    _BUDGET_PATTERNS = (
        re.compile(r"\$\s*(\d+(?:,\d+)*(?:\.\d+)?)(k)?\b"),
        re.compile(
            r"(\d+(?:,\d+)*(?:\.\d+)?)(k)?\s*(?:dollars|usd)?\s*"
            r"(?:/\s*mo(?:nth)?\b|per\s+month\b|a\s+month\b|monthly\b(?!\s+active))"
        ),
        re.compile(r"\bbudget\s*(?:of|is|:)?\s*(\d+(?:,\d+)*(?:\.\d+)?)(k)?\b"),
    )

    def __init__(self) -> None:
        """Initialize the orchestrator with all agents and build the workflow graph."""
        self.database_agent = DatabaseAgent()
        self.infrastructure_agent = InfrastructureAgent()
        self.cost_agent = CostAgent()
        self.security_agent = SecurityAgent()

        # Build the workflow graph
        self.workflow = self._build_graph()
        logger.info("orchestrator_initialized", agents=4)

    def _build_graph(self) -> Any:
        """Build the LangGraph workflow.

        Returns:
            Compiled workflow graph
        """
        # Create the state graph
        graph = StateGraph(WorkflowState)

        # Add nodes for each agent
        graph.add_node("parse_query", self._parse_query_node)
        graph.add_node("database_agent", self._database_node)
        graph.add_node("infrastructure_agent", self._infrastructure_node)
        graph.add_node("cost_agent", self._cost_node)
        graph.add_node("security_agent", self._security_node)
        graph.add_node("synthesize", self._synthesize_node)

        # Define the workflow flow
        graph.set_entry_point("parse_query")

        # Sequential flow: parse -> database -> infrastructure -> cost -> security -> synthesize.
        # Order matters: cost consumes database/infrastructure results and
        # security consumes the architecture chosen by infrastructure.
        graph.add_edge("parse_query", "database_agent")
        graph.add_edge("database_agent", "infrastructure_agent")
        graph.add_edge("infrastructure_agent", "cost_agent")
        graph.add_edge("cost_agent", "security_agent")
        graph.add_edge("security_agent", "synthesize")
        graph.add_edge("synthesize", END)

        # Compile the graph
        return graph.compile()

    # ------------------------------------------------------------------
    # Text-matching helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _mentions(text: str, *terms: str) -> bool:
        """Return True if any of *terms* occurs anywhere in *text* (plain substring match)."""
        return any(term in text for term in terms)

    @staticmethod
    def _has_word(text: str, *words: str) -> bool:
        """Return True if any of *words* occurs in *text* as a whole word.

        Used for short tokens whose substring match gives false positives,
        e.g. "eu" in "queue", "api" in "rapid", "rest" in "interest".
        """
        return any(re.search(rf"\b{re.escape(word)}\b", text) for word in words)

    @classmethod
    def _estimate_dau(cls, query: str, dau_override: int | None) -> int:
        """Estimate daily active users for a lower-cased query.

        Precedence: a positive ``dau_override``; then an explicit count parsed
        from the text (with k/thousand/m/million multipliers); then a default
        chosen from scale keywords.
        """
        if dau_override is not None and dau_override > 0:
            return dau_override

        match = cls._DAU_PATTERN.search(query)
        if match:
            value = float(match.group(1).replace(",", ""))
            multiplier = cls._DAU_MULTIPLIERS.get(match.group(2) or "", 1)
            return int(value * multiplier)

        # No explicit count: fall back to scale indicators.
        if "small" in query or "startup" in query:
            return 1_000
        if "medium" in query or "growing" in query:
            return 50_000
        if "large" in query or "enterprise" in query or "million" in query:
            return 500_000
        return 10_000  # Default

    @classmethod
    def _extract_budget(cls, query: str) -> float:
        """Return the monthly budget mentioned in a lower-cased query, or 0.0 if none."""
        for pattern in cls._BUDGET_PATTERNS:
            match = pattern.search(query)
            if match:
                amount = float(match.group(1).replace(",", ""))
                return amount * 1_000 if match.group(2) else amount
        return 0.0

    # ------------------------------------------------------------------
    # Graph nodes
    # ------------------------------------------------------------------

    def _parse_query_node(self, state: WorkflowState) -> WorkflowState:
        """Parse the user query into the structured context the agents consume.

        Uses keyword and regex heuristics on the lower-cased ``user_query``.
        A positive ``dau_override`` in the state takes precedence over any DAU
        parsed from the text.

        Args:
            state: Current workflow state

        Returns:
            The same state, updated with correlation_id, dau, qps, rps, data_gb,
            data_type, consistency, workload_type, budget_target,
            budget_conscious, data_sensitivity, compliance_required,
            public_facing, existing_stack and architecture.
        """
        logger.info("parse_query_start", query=state.get("user_query", ""))

        query = (state.get("user_query") or "").lower()

        dau = self._estimate_dau(query, state.get("dau_override"))

        # Rough load heuristic: assume ~1.5% of DAU are active at once, each
        # issuing ~1 request per minute (dau * 0.015 / 60), floored at 10.
        # NOTE(review): confirm the intended per-user request rate.
        qps = max(10, int(dau * 0.015 / 60))
        rps = qps

        # Data volume estimate (GB): ~1 GB per 100 DAU, floored at 10 GB.
        data_gb = max(10, int(dau / 100))

        # Detect data type
        if self._mentions(query, "unstructured", "document", "json"):
            data_type = "unstructured"
        elif self._mentions(query, "time-series", "metrics") or self._has_word(query, "logs"):
            data_type = "time-series"
        else:
            data_type = "structured"

        # Whole-word match so "we will eventually ..." doesn't imply eventual consistency.
        consistency = "eventual" if self._has_word(query, "eventual", "eventually consistent") else "strong"

        # Detect workload type
        if self._mentions(query, "real-time", "realtime", "chat", "websocket"):
            workload_type = "realtime"
        elif self._mentions(query, "graphql") or self._has_word(query, "api", "apis", "rest", "restful"):
            workload_type = "api"
        elif self._mentions(query, "background", "batch", "cron"):
            workload_type = "background"
        else:
            workload_type = "web"

        # Budget detection
        budget_target = self._extract_budget(query)
        budget_conscious = self._mentions(query, "cheap", "low cost", "budget") or budget_target > 0

        # Data sensitivity
        if self._mentions(query, "healthcare", "medical", "hipaa", "payment", "financial", "banking"):
            data_sensitivity = "critical"
        elif self._mentions(query, "sensitive", "private"):
            data_sensitivity = "high"
        else:
            data_sensitivity = "medium"

        # Compliance requirements
        compliance_required = []
        if self._mentions(query, "gdpr", "europe") or self._has_word(query, "eu"):
            compliance_required.append("gdpr")
        if self._mentions(query, "hipaa", "healthcare", "medical"):
            compliance_required.append("hipaa")
        if self._mentions(query, "pci", "payment", "credit card"):
            compliance_required.append("pci_dss")
        if self._mentions(query, "soc2", "soc 2"):
            compliance_required.append("soc2")

        # Public facing
        public_facing = not self._mentions(query, "internal", "intranet")

        # Reuse the caller's correlation ID; only mint one if it is missing/empty.
        correlation_id = state.get("correlation_id") or str(uuid.uuid4())

        logger.info(
            "parse_query_complete",
            correlation_id=correlation_id,
            dau=dau,
            qps=qps,
            data_type=data_type,
            workload_type=workload_type,
        )

        # Update state
        state.update({
            "correlation_id": correlation_id,
            "dau": dau,
            "qps": qps,
            "rps": rps,
            "data_gb": data_gb,
            "data_type": data_type,
            "consistency": consistency,
            "workload_type": workload_type,
            "budget_target": budget_target,
            "budget_conscious": budget_conscious,
            "data_sensitivity": data_sensitivity,
            "compliance_required": compliance_required,
            "public_facing": public_facing,
            "existing_stack": "none",
            "architecture": "monolith",  # Will be updated by infrastructure agent
        })

        return state

    async def _run_agent(
        self,
        state: WorkflowState,
        agent_name: str,
        display_name: str,
        agent: Any,
        agent_input: dict[str, Any],
    ) -> tuple[bool, Any]:
        """Run one agent with shared logging, error capture and short-circuiting.

        Args:
            state: Current workflow state; ``state["error"]`` is set on failure.
            agent_name: Prefix for log events, e.g. ``"database_agent"``.
            display_name: Human-readable name used in the error message.
            agent: Agent instance whose ``analyze`` coroutine receives ``agent_input``.
            agent_input: Payload passed to ``agent.analyze``.

        Returns:
            ``(True, result)`` on success; ``(False, None)`` when the agent was
            skipped because of an earlier error, or when it raised.
        """
        correlation_id = state.get("correlation_id")

        # An earlier agent already failed. The synthesize node discards all
        # agent output when an error is set, so further calls would be wasted.
        if state.get("error"):
            logger.info(f"{agent_name}_skipped", correlation_id=correlation_id)
            return False, None

        logger.info(f"{agent_name}_start", correlation_id=correlation_id)
        try:
            result = await agent.analyze(agent_input)
        except Exception as e:
            logger.error(f"{agent_name}_error", error=str(e), correlation_id=correlation_id)
            state["error"] = f"{display_name} agent error: {e}"
            return False, None

        logger.info(f"{agent_name}_complete", correlation_id=correlation_id)
        return True, result

    async def _database_node(self, state: WorkflowState) -> WorkflowState:
        """Execute database agent.

        Args:
            state: Current workflow state

        Returns:
            Updated state with ``database_result`` (or ``error``) set
        """
        succeeded, result = await self._run_agent(
            state,
            "database_agent",
            "Database",
            self.database_agent,
            {
                "user_query": state.get("user_query", ""),
                "dau": state.get("dau", 0),
                "qps": state.get("qps", 0),
                "data_gb": state.get("data_gb", 10),
                "data_type": state.get("data_type", "structured"),
                "consistency": state.get("consistency", "strong"),
                "api_key": state.get("api_key"),
            },
        )
        if succeeded:
            state["database_result"] = result
        return state

    async def _infrastructure_node(self, state: WorkflowState) -> WorkflowState:
        """Execute infrastructure agent.

        Args:
            state: Current workflow state

        Returns:
            Updated state with ``infrastructure_result`` (or ``error``) set and,
            when the result provides one, ``architecture`` for the security agent
        """
        succeeded, result = await self._run_agent(
            state,
            "infrastructure_agent",
            "Infrastructure",
            self.infrastructure_agent,
            {
                "user_query": state.get("user_query", ""),
                "dau": state.get("dau", 0),
                "rps": state.get("rps", 0),
                "workload_type": state.get("workload_type", "web"),
                "budget_conscious": state.get("budget_conscious", True),
                "existing_stack": state.get("existing_stack", "none"),
                "api_key": state.get("api_key"),
            },
        )
        if succeeded:
            state["infrastructure_result"] = result

            # Extract architecture for the security agent; tolerate results
            # that lack a well-formed "scale_info" mapping.
            scale_info = result.get("scale_info") if isinstance(result, dict) else None
            if isinstance(scale_info, dict):
                state["architecture"] = scale_info.get("suggested_architecture", "monolith")
        return state

    async def _cost_node(self, state: WorkflowState) -> WorkflowState:
        """Execute cost agent.

        Args:
            state: Current workflow state

        Returns:
            Updated state with ``cost_result`` (or ``error``) set
        """
        succeeded, result = await self._run_agent(
            state,
            "cost_agent",
            "Cost",
            self.cost_agent,
            {
                "user_query": state.get("user_query", ""),
                "dau": state.get("dau", 0),
                "budget_target": state.get("budget_target", 0),
                "infrastructure_recommendations": state.get("infrastructure_result", {}),
                "database_recommendations": state.get("database_result", {}),
                "api_key": state.get("api_key"),
            },
        )
        if succeeded:
            state["cost_result"] = result
        return state

    async def _security_node(self, state: WorkflowState) -> WorkflowState:
        """Execute security agent.

        Args:
            state: Current workflow state

        Returns:
            Updated state with ``security_result`` (or ``error``) set
        """
        succeeded, result = await self._run_agent(
            state,
            "security_agent",
            "Security",
            self.security_agent,
            {
                "user_query": state.get("user_query", ""),
                "architecture": state.get("architecture", "monolith"),
                "data_sensitivity": state.get("data_sensitivity", "medium"),
                "compliance_required": state.get("compliance_required", []),
                "public_facing": state.get("public_facing", True),
                "api_key": state.get("api_key"),
            },
        )
        if succeeded:
            state["security_result"] = result
        return state

    def _synthesize_node(self, state: WorkflowState) -> WorkflowState:
        """Synthesize all agent results into final recommendation.

        If any agent recorded an error, the final recommendation is an error
        payload and all agent results are omitted.

        Args:
            state: Current workflow state

        Returns:
            Updated state with ``final_recommendation`` set
        """
        logger.info("synthesize_start", correlation_id=state.get("correlation_id"))

        # Check for errors
        if state.get("error"):
            state["final_recommendation"] = {
                "status": "error",
                "error": state["error"],
                "query": state.get("user_query", ""),
                "correlation_id": state.get("correlation_id", ""),
            }
            return state

        # Build final recommendation
        final = {
            "status": "success",
            "query": state.get("user_query", ""),
            "correlation_id": state.get("correlation_id", ""),
            "parsed_context": {
                "dau": state.get("dau", 0),
                "qps": state.get("qps", 0),
                "data_type": state.get("data_type", ""),
                "workload_type": state.get("workload_type", ""),
                "data_sensitivity": state.get("data_sensitivity", ""),
                "compliance": state.get("compliance_required", []),
            },
            "recommendations": {
                "database": state.get("database_result", {}),
                "infrastructure": state.get("infrastructure_result", {}),
                "cost": state.get("cost_result", {}),
                "security": state.get("security_result", {}),
            },
        }

        state["final_recommendation"] = final
        logger.info("synthesize_complete", correlation_id=state.get("correlation_id"))

        return state

    async def process_query(self, user_query: str, dau_override: int | None = None, api_key: str | None = None) -> dict[str, Any]:
        """Process a user query through the full workflow.

        Args:
            user_query: User's tech stack question
            dau_override: Optional DAU override from user input (used when positive)
            api_key: Optional user-provided Anthropic API key

        Returns:
            Final recommendation dictionary. Never raises for workflow failures:
            those are returned as ``{"status": "error", ...}``.
        """
        correlation_id = str(uuid.uuid4())
        logger.info(
            "workflow_start",
            query=user_query,
            correlation_id=correlation_id,
            dau_override=dau_override,
            using_custom_key=api_key is not None,
        )

        # Initialize state
        initial_state: WorkflowState = {
            "user_query": user_query,
            "correlation_id": correlation_id,
            "dau_override": dau_override,
            "api_key": api_key,
        }

        try:
            # Run the workflow
            final_state = await self.workflow.ainvoke(initial_state)
        except Exception as e:
            # Top-level boundary: report graph failures as an error payload.
            logger.error("workflow_error", error=str(e), correlation_id=correlation_id)
            return {
                "status": "error",
                "error": str(e),
                "query": user_query,
                "correlation_id": correlation_id,
            }

        final_recommendation = final_state.get("final_recommendation", {})
        logger.info(
            "workflow_complete",
            correlation_id=final_state.get("correlation_id"),
            status=final_recommendation.get("status", "unknown"),
        )
        return final_recommendation