Tech Stack Advisor - Code Viewer

← Back to File Tree

user_memory.py

Language: python | Path: backend/src/core/user_memory.py | Lines: 522
"""User memory and authentication using Qdrant for long-term storage."""
import hashlib
import hmac
import secrets
import time
import uuid
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    VectorParams,
    PointStruct,
    Filter,
    FieldCondition,
    MatchValue,
)

from .config import settings
from .logging import get_logger

# Module-level logger; call sites in this file pass an event name plus keyword fields.
logger = get_logger(__name__)


class UserMemoryStore:
    """Manage users, query history, and feedback stored in Qdrant.

    Three collections are used: ``users``, ``user_queries`` and
    ``user_feedback``. All of them are created with ``VECTOR_DIM``-sized
    cosine vectors. Only ``user_queries`` stores real embeddings (and only
    when the embedding model could be loaded); the other collections store
    an all-zero placeholder vector and are used for their payloads.

    New points are keyed by the record's own UUID (``user_id``, ``query_id``
    or ``feedback_id``). Earlier revisions derived an integer ID from the
    built-in ``hash()`` of that string, which is salted per process and
    therefore not reproducible after a restart; such legacy points are still
    found through payload filters and updated in place.
    """

    # Vector size shared by every collection; 384 is the output dimension of
    # all-MiniLM-L6-v2, the model loaded in __init__.
    VECTOR_DIM = 384

    # Collection name -> payload fields that get a "keyword" index on creation.
    _COLLECTION_INDEXES = {
        "users": ("email", "user_id"),
        "user_queries": ("user_id", "query_id"),
        "user_feedback": ("user_id", "feedback_id"),
    }

    # Password hashing parameters. Stored hashes have the form
    # "pbkdf2_sha256$<iterations>$<salt hex>$<digest hex>".
    _PBKDF2_SCHEME = "pbkdf2_sha256"
    _PBKDF2_ITERATIONS = 600_000
    _SALT_BYTES = 16

    # Page size used when a listing has to walk a whole collection.
    _SCROLL_PAGE_SIZE = 256

    def __init__(self, use_local: bool = False):
        """Initialize the user memory store.

        Loads the sentence-transformers model (best effort), connects to
        Qdrant, and makes sure the required collections exist.

        Args:
            use_local: If True, use local in-memory Qdrant (for testing)
        """
        # Load embedding model for semantic search. Any failure (including
        # the package not being installed) leaves the store usable without
        # semantic search: embedding_model is set to None.
        try:
            from sentence_transformers import SentenceTransformer
            logger.info("loading_embedding_model", model="all-MiniLM-L6-v2")
            self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
            logger.info("embedding_model_loaded")
        except Exception as e:
            logger.error("embedding_model_load_failed", error=str(e))
            self.embedding_model = None

        if use_local:
            logger.info("initializing_local_user_memory")
            self.client = QdrantClient(":memory:")
        else:
            logger.info("initializing_qdrant_user_memory", url=settings.qdrant_url)
            self.client = QdrantClient(
                url=settings.qdrant_url,
                api_key=settings.qdrant_api_key,
            )

        self._ensure_collections()

    # ==================== INTERNAL HELPERS ====================

    def _ensure_collections(self) -> None:
        """Create every missing collection together with its payload indexes.

        Collections that already exist are left untouched (their indexes are
        not re-checked).
        """
        existing = {c.name for c in self.client.get_collections().collections}

        for name, indexed_fields in self._COLLECTION_INDEXES.items():
            if name in existing:
                continue

            logger.info(f"creating_{name}_collection")
            self.client.create_collection(
                collection_name=name,
                vectors_config=VectorParams(
                    size=self.VECTOR_DIM,
                    distance=Distance.COSINE,
                ),
            )
            # Keyword indexes enable filtering on these payload fields.
            for field_name in indexed_fields:
                self.client.create_payload_index(
                    collection_name=name,
                    field_name=field_name,
                    field_schema="keyword",
                )

        logger.info("user_memory_collections_ready")

    @staticmethod
    def _point_id(entity_id: str) -> str:
        """Return the Qdrant point ID for a record's UUID string.

        The canonical UUID text is used directly as the point ID, so the ID
        is identical in every process and cannot collide between records.

        Raises:
            ValueError: If ``entity_id`` is not a valid UUID.
        """
        return str(uuid.UUID(entity_id))

    def _zero_vector(self) -> List[float]:
        """Return the placeholder vector stored with payload-only points."""
        return [0.0] * self.VECTOR_DIM

    @staticmethod
    def _match(field: str, value: Any) -> Any:
        """Build a filter that matches points whose ``field`` equals ``value``."""
        return Filter(must=[FieldCondition(key=field, match=MatchValue(value=value))])

    def _find_user_point(self, field: str, value: Any) -> Any:
        """Return the first ``users`` point whose ``field`` equals ``value``.

        Returns None when nothing matches. Qdrant errors propagate to the
        caller.
        """
        points, _next_offset = self.client.scroll(
            collection_name="users",
            scroll_filter=self._match(field, value),
            limit=1,
        )
        return points[0] if points else None

    def _scroll_all(self, collection_name: str, scroll_filter: Any = None) -> List[Dict[str, Any]]:
        """Return the payload of every matching point, following pagination."""
        payloads: List[Dict[str, Any]] = []
        offset = None
        while True:
            points, offset = self.client.scroll(
                collection_name=collection_name,
                scroll_filter=scroll_filter,
                limit=self._SCROLL_PAGE_SIZE,
                offset=offset,
            )
            payloads.extend(dict(point.payload or {}) for point in points)
            # A missing offset marks the last page; the empty-page check only
            # guards against looping forever on an unexpected response.
            if offset is None or not points:
                return payloads

    @staticmethod
    def _newest_first(payloads: List[Dict[str, Any]], limit: int) -> List[Dict[str, Any]]:
        """Return up to ``limit`` payloads ordered by ``created_at`` descending.

        Payloads without a ``created_at`` value sort as if it were 0.
        """
        ordered = sorted(
            payloads,
            key=lambda item: item.get("created_at") or 0,
            reverse=True,
        )
        return ordered[: max(limit, 0)]

    @classmethod
    def _hash_password(
        cls,
        password: str,
        *,
        salt: Optional[bytes] = None,
        iterations: Optional[int] = None,
    ) -> str:
        """Hash ``password`` with salted PBKDF2-HMAC-SHA256.

        Args:
            password: Plain text password.
            salt: Salt bytes; a fresh random salt is generated when omitted.
            iterations: PBKDF2 round count; defaults to ``_PBKDF2_ITERATIONS``.

        Returns:
            ``"pbkdf2_sha256$<iterations>$<salt hex>$<digest hex>"``.
        """
        if salt is None:
            salt = secrets.token_bytes(cls._SALT_BYTES)
        if iterations is None:
            iterations = cls._PBKDF2_ITERATIONS
        digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, iterations)
        return f"{cls._PBKDF2_SCHEME}${iterations}${salt.hex()}${digest.hex()}"

    @classmethod
    def _verify_password_hash(cls, password: str, stored: Any) -> bool:
        """Check ``password`` against a stored hash in constant time.

        Accepts both the current PBKDF2 format and the legacy format (a bare
        unsalted SHA-256 hex digest). Malformed or non-string values never
        verify.
        """
        if not isinstance(stored, str) or not stored:
            return False

        if stored.startswith(f"{cls._PBKDF2_SCHEME}$"):
            try:
                _scheme, iterations_text, salt_hex, digest_hex = stored.split("$")
                expected = bytes.fromhex(digest_hex)
                candidate = hashlib.pbkdf2_hmac(
                    "sha256",
                    password.encode(),
                    bytes.fromhex(salt_hex),
                    int(iterations_text),
                )
            except (ValueError, OverflowError):
                # Wrong field count, bad hex, or an unusable iteration count.
                return False
            return hmac.compare_digest(candidate, expected)

        legacy = hashlib.sha256(password.encode()).hexdigest()
        return hmac.compare_digest(legacy.encode(), stored.encode())

    # ==================== USER MANAGEMENT ====================

    def create_user(
        self,
        email: str,
        password: Optional[str] = None,
        full_name: Optional[str] = None,
        oauth_provider: Optional[str] = None,
        oauth_id: Optional[str] = None,
        is_admin: bool = False,
    ) -> Dict[str, Any]:
        """Create a new user.

        Args:
            email: User email (unique identifier)
            password: Plain text password (stored only as a salted hash)
            full_name: User's full name
            oauth_provider: OAuth provider (google, microsoft, linkedin, facebook)
            oauth_id: OAuth provider user ID
            is_admin: Whether user is admin

        Returns:
            User data dictionary

        Raises:
            ValueError: If a user with this email already exists.
        """
        # NOTE: get_user_by_email returns None on lookup errors as well, so a
        # failed lookup is indistinguishable from "no such user" here.
        if self.get_user_by_email(email):
            raise ValueError(f"User with email {email} already exists")

        user_id = str(uuid.uuid4())

        # An empty or missing password leaves the account without a local
        # credential (e.g. OAuth-only users).
        hashed_password = self._hash_password(password) if password else None

        user_data = {
            "user_id": user_id,
            "email": email,
            "hashed_password": hashed_password,
            "full_name": full_name,
            "is_admin": is_admin,
            "is_active": True,
            "oauth_provider": oauth_provider,
            "oauth_id": oauth_id,
            "created_at": time.time(),
            "last_login": None,
            "total_queries": 0,
            "total_cost_usd": 0.0,
        }

        # Only the payload matters for users; the vector is a placeholder.
        self.client.upsert(
            collection_name="users",
            points=[
                PointStruct(
                    id=self._point_id(user_id),
                    vector=self._zero_vector(),
                    payload=user_data,
                )
            ],
        )

        logger.info("user_created", user_id=user_id, email=email, oauth=oauth_provider is not None)
        return user_data

    def get_user_by_email(self, email: str) -> Optional[Dict[str, Any]]:
        """Get user by email.

        Returns None when no user matches or when the lookup fails (the
        error is logged).
        """
        try:
            point = self._find_user_point("email", email)
            return dict(point.payload) if point is not None else None
        except Exception as e:
            logger.error("get_user_error", email=email, error=str(e))
            return None

    def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Get user by ID.

        Returns None when no user matches or when the lookup fails (the
        error is logged).
        """
        try:
            point = self._find_user_point("user_id", user_id)
            return dict(point.payload) if point is not None else None
        except Exception as e:
            logger.error("get_user_by_id_error", user_id=user_id, error=str(e))
            return None

    def verify_password(self, email: str, password: str) -> Optional[Dict[str, Any]]:
        """Verify user password and return user data if valid.

        On success ``last_login`` is updated, and a legacy SHA-256 hash is
        replaced by a salted PBKDF2 hash of the same password. The returned
        dictionary is the user data as it was read before those updates.

        Returns:
            The user data, or None if the user is unknown, has no password,
            or the password does not match.
        """
        user = self.get_user_by_email(email)
        stored = user.get("hashed_password") if user else None
        if not stored or not self._verify_password_hash(password, stored):
            return None

        updates: Dict[str, Any] = {"last_login": time.time()}
        if not stored.startswith(f"{self._PBKDF2_SCHEME}$"):
            # Transparently upgrade the weak legacy hash now that the plain
            # text password is known to be correct.
            updates["hashed_password"] = self._hash_password(password)
        self.update_user(user["user_id"], updates)
        return user

    def update_user(self, user_id: str, updates: Dict[str, Any]) -> bool:
        """Merge ``updates`` into a user's stored data.

        The existing point is located through its ``user_id`` payload field
        and rewritten under its own point ID, so the update lands on the
        same point regardless of how that ID was originally derived.

        Returns:
            True if the user was found and rewritten; False if the user does
            not exist or the lookup failed.
        """
        try:
            point = self._find_user_point("user_id", user_id)
        except Exception as e:
            logger.error("get_user_by_id_error", user_id=user_id, error=str(e))
            return False
        if point is None:
            return False

        merged = dict(point.payload or {})
        merged.update(updates)

        self.client.upsert(
            collection_name="users",
            points=[
                PointStruct(
                    id=point.id,
                    vector=self._zero_vector(),
                    payload=merged,
                )
            ],
        )

        logger.info("user_updated", user_id=user_id)
        return True

    def get_all_users(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Get up to ``limit`` users (for admin dashboard), in no particular order.

        Returns an empty list if the lookup fails (the error is logged).
        """
        try:
            points, _next_offset = self.client.scroll(
                collection_name="users",
                limit=limit,
            )
            return [dict(point.payload) for point in points]
        except Exception as e:
            logger.error("get_all_users_error", error=str(e))
            return []

    # ==================== QUERY HISTORY ====================

    def store_query(
        self,
        user_id: str,
        query: str,
        correlation_id: str,
        dau: Optional[int],
        parsed_context: Optional[Dict[str, Any]],
        recommendations: Optional[Dict[str, Any]],
        status: str,
        error_message: Optional[str],
        tokens_used: int,
        cost_usd: float,
    ) -> str:
        """Store a user query with semantic embedding for similarity search.

        The query text is embedded when the embedding model is available;
        otherwise a zero vector is stored and a warning is logged. The
        user's ``total_queries`` and ``total_cost_usd`` counters are bumped
        when the user exists.

        Returns:
            Query ID
        """
        query_id = str(uuid.uuid4())

        if self.embedding_model is not None:
            query_embedding = self.embedding_model.encode(query).tolist()
        else:
            query_embedding = self._zero_vector()
            logger.warning("store_query_without_embedding", user_id=user_id)

        query_data = {
            "query_id": query_id,
            "user_id": user_id,
            "query": query,
            "correlation_id": correlation_id,
            "dau": dau,
            "parsed_context": parsed_context,
            "recommendations": recommendations,
            "status": status,
            "error_message": error_message,
            "tokens_used": tokens_used,
            "cost_usd": cost_usd,
            "created_at": time.time(),
        }

        self.client.upsert(
            collection_name="user_queries",
            points=[
                PointStruct(
                    id=self._point_id(query_id),
                    vector=query_embedding,
                    payload=query_data,
                )
            ],
        )

        # Update user stats (read-modify-write; not atomic across writers).
        user = self.get_user_by_id(user_id)
        if user:
            self.update_user(user_id, {
                "total_queries": user.get("total_queries", 0) + 1,
                "total_cost_usd": user.get("total_cost_usd", 0.0) + cost_usd,
            })

        logger.info("query_stored", user_id=user_id, query_id=query_id, cost=cost_usd)
        return query_id

    def search_similar_queries(
        self,
        user_id: str,
        query: str,
        limit: int = 5,
    ) -> List[Dict[str, Any]]:
        """Search for similar queries by the same user using semantic search.

        Args:
            user_id: User ID to filter queries
            query: Query text to find similar queries for
            limit: Maximum number of similar queries to return

        Returns:
            List of stored query payloads, each with an added
            ``similarity_score``. Empty when the embedding model is not
            loaded or the search fails (both cases are logged).
        """
        if self.embedding_model is None:
            logger.warning(
                "semantic_search_disabled",
                user_id=user_id,
                reason="Embedding model not loaded"
            )
            return []

        try:
            query_embedding = self.embedding_model.encode(query).tolist()

            hits = self.client.search(
                collection_name="user_queries",
                query_vector=query_embedding,
                query_filter=self._match("user_id", user_id),
                limit=limit,
            )

            similar_queries = [
                {**dict(hit.payload), "similarity_score": hit.score}
                for hit in hits
            ]

            logger.info(
                "similar_queries_found",
                user_id=user_id,
                count=len(similar_queries)
            )

            return similar_queries

        except Exception as e:
            logger.error("search_similar_queries_error", user_id=user_id, error=str(e))
            return []

    def get_user_queries(
        self,
        user_id: str,
        limit: int = 50,
    ) -> List[Dict[str, Any]]:
        """Get the ``limit`` most recent queries by a user, newest first.

        All of the user's queries are read before sorting, because a single
        scroll page is returned in point-ID order rather than by time.
        Returns an empty list if the lookup fails (the error is logged).
        """
        try:
            payloads = self._scroll_all("user_queries", self._match("user_id", user_id))
            return self._newest_first(payloads, limit)
        except Exception as e:
            logger.error("get_user_queries_error", user_id=user_id, error=str(e))
            return []

    def get_all_queries(self, limit: int = 200) -> List[Dict[str, Any]]:
        """Get the ``limit`` most recent queries of all users (for admin dashboard).

        Reads the whole ``user_queries`` collection before sorting. Returns
        an empty list if the lookup fails (the error is logged).
        """
        try:
            return self._newest_first(self._scroll_all("user_queries"), limit)
        except Exception as e:
            logger.error("get_all_queries_error", error=str(e))
            return []

    # ==================== FEEDBACK ====================

    def store_feedback(
        self,
        user_id: str,
        query_id: Optional[str],
        rating: Optional[int],
        comment: Optional[str],
        feedback_type: Optional[str],
    ) -> str:
        """Store user feedback.

        Returns:
            Feedback ID
        """
        feedback_id = str(uuid.uuid4())

        feedback_data = {
            "feedback_id": feedback_id,
            "user_id": user_id,
            "query_id": query_id,
            "rating": rating,
            "comment": comment,
            "feedback_type": feedback_type,
            "created_at": time.time(),
        }

        # Feedback is payload-only; the vector is a placeholder.
        self.client.upsert(
            collection_name="user_feedback",
            points=[
                PointStruct(
                    id=self._point_id(feedback_id),
                    vector=self._zero_vector(),
                    payload=feedback_data,
                )
            ],
        )

        logger.info("feedback_stored", user_id=user_id, feedback_id=feedback_id, rating=rating)
        return feedback_id

    def get_user_feedback(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
        """Get up to ``limit`` feedback entries by a user, in no particular order.

        Returns an empty list if the lookup fails (the error is logged).
        """
        try:
            points, _next_offset = self.client.scroll(
                collection_name="user_feedback",
                scroll_filter=self._match("user_id", user_id),
                limit=limit,
            )
            return [dict(point.payload) for point in points]
        except Exception as e:
            logger.error("get_user_feedback_error", user_id=user_id, error=str(e))
            return []

    def get_all_feedback(self, limit: int = 200) -> List[Dict[str, Any]]:
        """Get the ``limit`` most recent feedback entries (for admin dashboard).

        Reads the whole ``user_feedback`` collection before sorting. Returns
        an empty list if the lookup fails (the error is logged).
        """
        try:
            return self._newest_first(self._scroll_all("user_feedback"), limit)
        except Exception as e:
            logger.error("get_all_feedback_error", error=str(e))
            return []

# Process-wide store, created lazily by get_user_memory_store().
_user_memory_store: Optional[UserMemoryStore] = None


def get_user_memory_store(use_local: bool = False) -> UserMemoryStore:
    """Return the shared UserMemoryStore, building it on first use.

    Args:
        use_local: Forwarded to ``UserMemoryStore`` when the instance is
            created; it has no effect once an instance already exists.

    Returns:
        The module-wide store instance.
    """
    global _user_memory_store
    if _user_memory_store is not None:
        return _user_memory_store
    _user_memory_store = UserMemoryStore(use_local=use_local)
    return _user_memory_store