"""Qdrant vector store wrapper."""
import uuid
from typing import Any, Dict, List

from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    FieldCondition,
    Filter,
    MatchValue,
    PointStruct,
    VectorParams,
)

from ..core.config import settings
from ..core.logging import get_logger
from .embeddings import get_embedding_model
logger = get_logger(__name__)
class VectorStore:
    """Qdrant vector store for RAG.

    Wraps a QdrantClient (remote or in-memory) plus an embedding model,
    exposing ingest (`add_documents`), retrieval (`search`), and collection
    management helpers.
    """

    def __init__(
        self,
        collection_name: str = "tech_stack_knowledge",
        use_local: bool = False,
    ) -> None:
        """Initialize the vector store.

        Args:
            collection_name: Name of the Qdrant collection
            use_local: If True, use local in-memory Qdrant (for testing)
        """
        self.collection_name = collection_name
        self.embedding_model = get_embedding_model()
        if use_local:
            logger.info("initializing_local_vectorstore", collection=collection_name)
            # ":memory:" spins up an ephemeral in-process instance -- no server needed.
            self.client = QdrantClient(":memory:")
        else:
            logger.info(
                "initializing_qdrant_client",
                collection=collection_name,
                url=settings.qdrant_url,
            )
            self.client = QdrantClient(
                url=settings.qdrant_url,
                api_key=settings.qdrant_api_key,
            )
        self._ensure_collection()

    def _ensure_collection(self) -> None:
        """Ensure the collection exists, create if not.

        Raises:
            Exception: Re-raises any client error after logging it.
        """
        try:
            collections = self.client.get_collections().collections
            exists = any(c.name == self.collection_name for c in collections)
            if not exists:
                logger.info("creating_collection", collection=self.collection_name)
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=VectorParams(
                        # Vector size must match the embedding model's output.
                        size=self.embedding_model.dimension,
                        distance=Distance.COSINE,
                    ),
                )
                logger.info("collection_created", collection=self.collection_name)
            else:
                logger.info("collection_exists", collection=self.collection_name)
        except Exception as e:
            logger.error("collection_error", error=str(e))
            raise

    def add_documents(
        self,
        documents: List[Dict[str, Any]],
        batch_size: int = 100,
    ) -> int:
        """Add documents to the vector store.

        Args:
            documents: List of documents with 'text' and 'metadata' fields
                Example: [
                    {
                        "text": "PostgreSQL is a relational database...",
                        "metadata": {
                            "category": "database",
                            "technology": "postgresql",
                            "source": "official_docs"
                        }
                    }
                ]
            batch_size: Batch size for embedding and uploading

        Returns:
            Number of documents added
        """
        logger.info("adding_documents", count=len(documents))
        if not documents:
            # Nothing to embed or upload.
            return 0
        texts = [doc["text"] for doc in documents]
        embeddings = self.embedding_model.embed_batch(texts, batch_size=batch_size)
        points = []
        for doc, embedding in zip(documents, embeddings):
            point = PointStruct(
                # BUG FIX: the previous sequential ids (0, 1, 2, ...) restarted
                # at 0 on every call, so a second call to add_documents silently
                # overwrote earlier points via upsert. Random UUIDs (a valid
                # Qdrant point id type) keep every upload distinct.
                id=str(uuid.uuid4()),
                vector=embedding,
                payload={
                    # Full text stored alongside the metadata so search results
                    # can be returned without a second lookup.
                    "text": doc["text"],
                    **doc.get("metadata", {}),
                },
            )
            points.append(point)
        # Upload in batches to bound request size.
        for i in range(0, len(points), batch_size):
            batch = points[i : i + batch_size]
            self.client.upsert(
                collection_name=self.collection_name,
                points=batch,
            )
            logger.info("batch_uploaded", start=i, end=min(i + batch_size, len(points)))
        logger.info("documents_added", count=len(documents))
        return len(documents)

    def search(
        self,
        query: str,
        limit: int = 5,
        score_threshold: float = 0.0,
        filters: Dict[str, Any] | None = None,
    ) -> List[Dict[str, Any]]:
        """Search for similar documents.

        Args:
            query: Search query text
            limit: Maximum number of results
            score_threshold: Minimum similarity score (0-1)
            filters: Optional metadata filters
                Example: {"category": "database", "technology": "postgresql"}

        Returns:
            List of dicts with 'text', 'score', and 'metadata' keys,
            ordered by similarity.
        """
        logger.info("searching", query=query[:100], limit=limit)
        query_embedding = self.embedding_model.embed_text(query)
        # Translate the flat filter dict into a Qdrant AND-filter
        # (every key/value pair must match exactly).
        query_filter = None
        if filters:
            conditions = [
                FieldCondition(key=key, match=MatchValue(value=value))
                for key, value in filters.items()
            ]
            query_filter = Filter(must=conditions)
        # query_points is the modern unified search API (qdrant-client >= 1.10).
        results = self.client.query_points(
            collection_name=self.collection_name,
            query=query_embedding,
            limit=limit,
            score_threshold=score_threshold,
            query_filter=query_filter,
        ).points
        # Split payload back into text + metadata for callers.
        formatted_results = []
        for result in results:
            formatted_results.append({
                "text": result.payload.get("text", ""),
                "score": result.score,
                "metadata": {
                    k: v for k, v in result.payload.items() if k != "text"
                },
            })
        logger.info("search_complete", results=len(formatted_results))
        return formatted_results

    def delete_collection(self) -> None:
        """Delete the collection (use with caution)."""
        logger.warning("deleting_collection", collection=self.collection_name)
        self.client.delete_collection(collection_name=self.collection_name)

    def get_collection_info(self) -> Dict[str, Any]:
        """Get collection statistics.

        Returns:
            Dictionary with collection information
        """
        info = self.client.get_collection(collection_name=self.collection_name)
        return {
            "name": self.collection_name,
            # 'vectors_count' kept for backward compatibility; both keys
            # report points_count, the attribute the client actually exposes.
            "vectors_count": info.points_count or 0,
            "points_count": info.points_count or 0,
            "status": info.status,
        }
# Global vector store instance (lazy loaded)
_vector_store: VectorStore | None = None
def get_vector_store(use_local: bool = False) -> VectorStore:
    """Return the process-wide vector store, creating it on first use.

    Args:
        use_local: If True, use local in-memory store (for testing).
            NOTE: only honored on the call that creates the singleton;
            subsequent calls return the existing instance regardless.

    Returns:
        Singleton vector store instance
    """
    global _vector_store
    if _vector_store is not None:
        return _vector_store
    _vector_store = VectorStore(use_local=use_local)
    return _vector_store