Tech Stack Advisor - Code Viewer

← Back to File Tree

vectorstore.py

Language: python | Path: backend/src/rag/vectorstore.py | Lines: 225
"""Qdrant vector store wrapper."""
import uuid
from typing import Any, Dict, List

from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    FieldCondition,
    Filter,
    MatchValue,
    PointStruct,
    VectorParams,
)

from ..core.config import settings
from ..core.logging import get_logger
from .embeddings import get_embedding_model

logger = get_logger(__name__)


class VectorStore:
    """Qdrant-backed vector store for RAG.

    Wraps a Qdrant client together with the project embedding model and
    exposes document ingestion, similarity search, and collection
    management for a single collection.
    """

    def __init__(
        self,
        collection_name: str = "tech_stack_knowledge",
        use_local: bool = False,
    ) -> None:
        """Initialize the vector store.

        Args:
            collection_name: Name of the Qdrant collection.
            use_local: If True, use local in-memory Qdrant (for testing).
        """
        self.collection_name = collection_name
        self.embedding_model = get_embedding_model()

        if use_local:
            logger.info("initializing_local_vectorstore", collection=collection_name)
            self.client = QdrantClient(":memory:")
        else:
            logger.info(
                "initializing_qdrant_client",
                collection=collection_name,
                url=settings.qdrant_url,
            )
            self.client = QdrantClient(
                url=settings.qdrant_url,
                api_key=settings.qdrant_api_key,
            )

        self._ensure_collection()

    def _ensure_collection(self) -> None:
        """Create the collection if it does not already exist.

        Raises:
            Exception: Re-raises any Qdrant client error after logging it.
        """
        try:
            collections = self.client.get_collections().collections
            exists = any(c.name == self.collection_name for c in collections)

            if not exists:
                logger.info("creating_collection", collection=self.collection_name)
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=VectorParams(
                        # Vector size must match the embedding model output.
                        size=self.embedding_model.dimension,
                        distance=Distance.COSINE,
                    ),
                )
                logger.info("collection_created", collection=self.collection_name)
            else:
                logger.info("collection_exists", collection=self.collection_name)

        except Exception as e:
            logger.error("collection_error", error=str(e))
            raise

    @staticmethod
    def _point_id(text: str) -> str:
        """Return a deterministic UUID point id derived from document text.

        Content-derived ids make ingestion idempotent: re-adding the same
        corpus upserts the same points instead of clobbering whatever
        already sits at ids 0..N-1 (the bug with index-based ids). Note
        that documents with byte-identical text collapse into one point.
        """
        return str(uuid.uuid5(uuid.NAMESPACE_URL, text))

    def add_documents(
        self,
        documents: List[Dict[str, Any]],
        batch_size: int = 100,
    ) -> int:
        """Add documents to the vector store.

        Args:
            documents: List of documents with 'text' and 'metadata' fields
                Example: [
                    {
                        "text": "PostgreSQL is a relational database...",
                        "metadata": {
                            "category": "database",
                            "technology": "postgresql",
                            "source": "official_docs"
                        }
                    }
                ]
            batch_size: Batch size for embedding and uploading

        Returns:
            Number of documents added
        """
        if not documents:
            logger.info("adding_documents", count=0)
            return 0

        logger.info("adding_documents", count=len(documents))

        texts = [doc["text"] for doc in documents]
        embeddings = self.embedding_model.embed_batch(texts, batch_size=batch_size)

        # BUGFIX: ids were previously the enumerate index (0..N-1), so a
        # second call to add_documents overwrote the first batch's points.
        # Deterministic UUIDs derived from the text avoid that entirely.
        points = [
            PointStruct(
                id=self._point_id(doc["text"]),
                vector=embedding,
                payload={
                    "text": doc["text"],
                    **doc.get("metadata", {}),
                },
            )
            for doc, embedding in zip(documents, embeddings)
        ]

        # Upload in batches to keep individual requests small.
        for i in range(0, len(points), batch_size):
            batch = points[i : i + batch_size]
            self.client.upsert(
                collection_name=self.collection_name,
                points=batch,
            )
            logger.info("batch_uploaded", start=i, end=min(i + batch_size, len(points)))

        logger.info("documents_added", count=len(documents))
        return len(documents)

    def search(
        self,
        query: str,
        limit: int = 5,
        score_threshold: float = 0.0,
        filters: Dict[str, Any] | None = None,
    ) -> List[Dict[str, Any]]:
        """Search for similar documents.

        Args:
            query: Search query text
            limit: Maximum number of results
            score_threshold: Minimum similarity score (0-1)
            filters: Optional metadata filters; all conditions must match.
                Example: {"category": "database", "technology": "postgresql"}

        Returns:
            List of dicts with 'text', 'score', and 'metadata' keys,
            ordered by descending similarity.
        """
        logger.info("searching", query=query[:100], limit=limit)

        query_embedding = self.embedding_model.embed_text(query)

        # Translate the flat filter dict into a Qdrant must-match filter.
        query_filter = None
        if filters:
            conditions = [
                FieldCondition(key=key, match=MatchValue(value=value))
                for key, value in filters.items()
            ]
            query_filter = Filter(must=conditions)

        # query_points is the modern unified search API (qdrant-client >= 1.10).
        results = self.client.query_points(
            collection_name=self.collection_name,
            query=query_embedding,
            limit=limit,
            score_threshold=score_threshold,
            query_filter=query_filter,
        ).points

        # Split payload back into the text body and its metadata.
        formatted_results = [
            {
                "text": result.payload.get("text", ""),
                "score": result.score,
                "metadata": {
                    k: v for k, v in result.payload.items() if k != "text"
                },
            }
            for result in results
        ]

        logger.info("search_complete", results=len(formatted_results))
        return formatted_results

    def delete_collection(self) -> None:
        """Delete the collection (use with caution)."""
        logger.warning("deleting_collection", collection=self.collection_name)
        self.client.delete_collection(collection_name=self.collection_name)

    def get_collection_info(self) -> Dict[str, Any]:
        """Get collection statistics.

        Returns:
            Dictionary with collection name, point counts, and status.
        """
        info = self.client.get_collection(collection_name=self.collection_name)
        # 'vectors_count' mirrors 'points_count' for backward-compatible
        # callers; Qdrant's collection info exposes points_count.
        return {
            "name": self.collection_name,
            "vectors_count": info.points_count or 0,
            "points_count": info.points_count or 0,
            "status": info.status,
        }


# Process-wide singleton, constructed lazily on first access.
_vector_store: VectorStore | None = None


def get_vector_store(use_local: bool = False) -> VectorStore:
    """Return the process-wide vector store, creating it on first use.

    Note:
        ``use_local`` only takes effect on the very first call; subsequent
        calls return the already-constructed singleton regardless of the
        flag's value.

    Args:
        use_local: If True, use local in-memory store (for testing)

    Returns:
        Singleton vector store instance
    """
    global _vector_store
    if _vector_store is not None:
        return _vector_store
    _vector_store = VectorStore(use_local=use_local)
    return _vector_store