"""User memory and authentication using Qdrant for long-term storage."""
import hashlib
import hmac
import secrets
import time
import uuid
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    FieldCondition,
    Filter,
    MatchValue,
    PointStruct,
    VectorParams,
)

from .config import settings
from .logging import get_logger
# Module logger; calls in this file pass an event name plus keyword context
# fields (e.g. user_id=..., error=...).
logger = get_logger(__name__)
class UserMemoryStore:
    """Manages users, query history, and feedback in Qdrant.

    Three collections are used: ``users``, ``user_queries`` and
    ``user_feedback``. All of them are created with ``VECTOR_DIM``-sized
    cosine vectors, but only ``user_queries`` stores real sentence embeddings
    (when the embedding model loaded). The other two store a zero vector and
    use Qdrant as a payload store, filtered through keyword indexes.

    Point IDs are derived deterministically from each record's own UUID
    (``user_id`` / ``query_id`` / ``feedback_id``), so the same record maps to
    the same point across process restarts.
    """

    # Output size of the all-MiniLM-L6-v2 sentence-transformer used for query
    # embeddings; every collection uses it so the schema is uniform.
    VECTOR_DIM = 384

    # PBKDF2-HMAC-SHA256 work factor for newly hashed passwords. Each stored
    # hash records its own count, so raising this only affects new hashes;
    # older ones are upgraded on the user's next successful login.
    PASSWORD_HASH_ITERATIONS = 600_000

    # Scheme tag of the salted format "<scheme>$<iterations>$<salt>$<digest>".
    _PASSWORD_SCHEME = "pbkdf2_sha256"

    # Collections owned by this store -> payload fields given a keyword index
    # (required for the filtered scroll/search calls below).
    _COLLECTION_INDEXES = {
        "users": ("email", "user_id"),
        "user_queries": ("user_id", "query_id"),
        "user_feedback": ("user_id", "feedback_id"),
    }

    def __init__(self, use_local: bool = False):
        """Initialize the user memory store.

        Loads the embedding model (best effort), connects to Qdrant and makes
        sure all collections exist.

        Args:
            use_local: If True, use local in-memory Qdrant (for testing)
        """
        # Load embedding model for semantic search. Failure is tolerated:
        # queries are then stored with zero vectors and similarity search
        # returns no results.
        try:
            from sentence_transformers import SentenceTransformer

            logger.info("loading_embedding_model", model="all-MiniLM-L6-v2")
            self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
            logger.info("embedding_model_loaded")
        except Exception as e:
            logger.error("embedding_model_load_failed", error=str(e))
            self.embedding_model = None

        if use_local:
            logger.info("initializing_local_user_memory")
            self.client = QdrantClient(":memory:")
        else:
            logger.info("initializing_qdrant_user_memory", url=settings.qdrant_url)
            self.client = QdrantClient(
                url=settings.qdrant_url,
                api_key=settings.qdrant_api_key,
            )
        self._ensure_collections()

    def _ensure_collections(self) -> None:
        """Create every missing collection together with its keyword indexes."""
        existing = {c.name for c in self.client.get_collections().collections}
        for name, indexed_fields in self._COLLECTION_INDEXES.items():
            if name in existing:
                continue
            logger.info(f"creating_{name}_collection")
            self.client.create_collection(
                collection_name=name,
                vectors_config=VectorParams(
                    size=self.VECTOR_DIM,
                    distance=Distance.COSINE,
                ),
            )
            # Keyword indexes let scroll/search filter on these payload fields.
            for field_name in indexed_fields:
                self.client.create_payload_index(
                    collection_name=name,
                    field_name=field_name,
                    field_schema="keyword",
                )
        logger.info("user_memory_collections_ready")

    # ==================== INTERNAL HELPERS ====================

    @staticmethod
    def _point_id(record_id: str) -> str:
        """Return a stable Qdrant point ID for a record identifier.

        A UUID identifier is used as-is (in canonical form), since Qdrant
        accepts UUID strings as point IDs. Any other string is mapped through
        ``uuid5`` so the result is still deterministic. Unlike the built-in
        ``hash()`` of a ``str``, the result does not change between runs.
        """
        try:
            return str(uuid.UUID(record_id))
        except ValueError:
            return str(uuid.uuid5(uuid.NAMESPACE_OID, record_id))

    def _zero_vector(self) -> List[float]:
        """Return the placeholder vector for points used only as payload storage."""
        return [0.0] * self.VECTOR_DIM

    @staticmethod
    def _match(key: str, value: str) -> Any:
        """Build a Qdrant filter requiring payload ``key`` to equal ``value``."""
        return Filter(must=[FieldCondition(key=key, match=MatchValue(value=value))])

    def _find_one(self, collection: str, key: str, value: str) -> Any:
        """Return the first point whose payload ``key`` equals ``value``, or None.

        Client errors propagate; callers decide how to log them.
        """
        points, _next_offset = self.client.scroll(
            collection_name=collection,
            scroll_filter=self._match(key, value),
            limit=1,
        )
        return points[0] if points else None

    def _scroll_all_payloads(
        self,
        collection: str,
        scroll_filter: Any = None,
        page_size: int = 256,
    ) -> List[Dict[str, Any]]:
        """Return the payload of every point matching ``scroll_filter``.

        Follows scroll pagination until the client reports no next page.
        """
        payloads: List[Dict[str, Any]] = []
        offset = None
        while True:
            points, offset = self.client.scroll(
                collection_name=collection,
                scroll_filter=scroll_filter,
                limit=page_size,
                offset=offset,
            )
            payloads.extend(dict(point.payload) for point in points)
            if offset is None or not points:
                return payloads

    @classmethod
    def _hash_password(cls, password: str) -> str:
        """Return a salted PBKDF2-HMAC-SHA256 hash string for ``password``.

        Format: ``pbkdf2_sha256$<iterations>$<salt hex>$<digest hex>``.
        """
        iterations = cls.PASSWORD_HASH_ITERATIONS
        salt = secrets.token_bytes(16)
        digest = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iterations)
        return f"{cls._PASSWORD_SCHEME}${iterations}${salt.hex()}${digest.hex()}"

    @classmethod
    def _check_password(cls, password: str, stored: str) -> bool:
        """Return True if ``password`` matches the stored hash.

        Accepts the salted PBKDF2 format as well as the legacy unsalted
        SHA-256 hex digest written by earlier versions. Comparisons use
        ``hmac.compare_digest`` (constant time).
        """
        if not isinstance(stored, str) or not stored:
            return False
        if stored.startswith(cls._PASSWORD_SCHEME + "$"):
            try:
                _scheme, iterations, salt_hex, digest_hex = stored.split("$")
                expected = bytes.fromhex(digest_hex)
                actual = hashlib.pbkdf2_hmac(
                    "sha256",
                    password.encode("utf-8"),
                    bytes.fromhex(salt_hex),
                    int(iterations),
                )
            except (ValueError, OverflowError):
                # Malformed stored hash: treat as a non-match rather than crash login.
                return False
            return hmac.compare_digest(actual, expected)
        # Legacy format: unsalted SHA-256 hex digest.
        legacy = hashlib.sha256(password.encode("utf-8")).hexdigest()
        return hmac.compare_digest(legacy.encode("utf-8"), stored.encode("utf-8"))

    @classmethod
    def _needs_rehash(cls, stored: str) -> bool:
        """Return True if ``stored`` is legacy or weaker than the current work factor."""
        scheme, _sep, rest = stored.partition("$")
        if scheme != cls._PASSWORD_SCHEME:
            return True
        try:
            return int(rest.split("$", 1)[0]) < cls.PASSWORD_HASH_ITERATIONS
        except ValueError:
            return True

    # ==================== USER MANAGEMENT ====================

    def create_user(
        self,
        email: str,
        password: Optional[str] = None,
        full_name: Optional[str] = None,
        oauth_provider: Optional[str] = None,
        oauth_id: Optional[str] = None,
        is_admin: bool = False,
    ) -> Dict[str, Any]:
        """Create a new user.

        Args:
            email: User email (unique identifier)
            password: Plain text password (stored only as a salted hash)
            full_name: User's full name
            oauth_provider: OAuth provider (google, microsoft, linkedin, facebook)
            oauth_id: OAuth provider user ID
            is_admin: Whether user is admin

        Returns:
            User data dictionary

        Raises:
            ValueError: If a user with ``email`` already exists.
        """
        # NOTE(review): check-then-insert is not atomic; two concurrent sign-ups
        # with the same email could both pass this check.
        if self.get_user_by_email(email):
            raise ValueError(f"User with email {email} already exists")

        user_id = str(uuid.uuid4())
        user_data = {
            "user_id": user_id,
            "email": email,
            "hashed_password": self._hash_password(password) if password else None,
            "full_name": full_name,
            "is_admin": is_admin,
            "is_active": True,
            "oauth_provider": oauth_provider,
            "oauth_id": oauth_id,
            "created_at": time.time(),
            "last_login": None,
            "total_queries": 0,
            "total_cost_usd": 0.0,
        }
        # Payload-only record: the vector is a zero placeholder.
        self.client.upsert(
            collection_name="users",
            points=[
                PointStruct(
                    id=self._point_id(user_id),
                    vector=self._zero_vector(),
                    payload=user_data,
                )
            ],
        )
        logger.info("user_created", user_id=user_id, email=email, oauth=oauth_provider is not None)
        return user_data

    def get_user_by_email(self, email: str) -> Optional[Dict[str, Any]]:
        """Get user by exact email match; None if absent or on lookup error."""
        try:
            record = self._find_one("users", "email", email)
            return dict(record.payload) if record is not None else None
        except Exception as e:
            logger.error("get_user_error", email=email, error=str(e))
            return None

    def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Get user by ID; None if absent or on lookup error."""
        try:
            record = self._find_one("users", "user_id", user_id)
            return dict(record.payload) if record is not None else None
        except Exception as e:
            logger.error("get_user_by_id_error", user_id=user_id, error=str(e))
            return None

    def verify_password(self, email: str, password: str) -> Optional[Dict[str, Any]]:
        """Verify user password and return user data if valid.

        On success, ``last_login`` is refreshed. A hash in the legacy or a
        weaker format is transparently replaced by a current one.

        NOTE(review): ``is_active`` is not checked here; callers must enforce it
        if required.

        Returns:
            The user dict on success, otherwise None (unknown email, no
            password set, e.g. OAuth-only user, or wrong password).
        """
        user = self.get_user_by_email(email)
        stored = user.get("hashed_password") if user else None
        if not stored or not self._check_password(password, stored):
            return None

        login_updates: Dict[str, Any] = {"last_login": time.time()}
        if self._needs_rehash(stored):
            login_updates["hashed_password"] = self._hash_password(password)
        self.update_user(user["user_id"], login_updates)
        user.update(login_updates)
        return user

    def update_user(self, user_id: str, updates: Dict[str, Any]) -> bool:
        """Merge ``updates`` into the stored user record.

        The record is rewritten under the point ID it is actually stored at,
        so users created with the old ``hash()``-based IDs are updated in
        place instead of being duplicated.

        Returns:
            True if the user was found and updated; False if not found or if
            the lookup failed.
        """
        try:
            record = self._find_one("users", "user_id", user_id)
            user = dict(record.payload) if record is not None else None
        except Exception as e:
            logger.error("get_user_by_id_error", user_id=user_id, error=str(e))
            return False
        if user is None:
            return False

        user.update(updates)
        self.client.upsert(
            collection_name="users",
            points=[PointStruct(id=record.id, vector=self._zero_vector(), payload=user)],
        )
        logger.info("user_updated", user_id=user_id)
        return True

    def get_all_users(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Get up to ``limit`` users (for admin dashboard).

        NOTE(review): returns a single scroll page, not necessarily all users.
        """
        try:
            points, _next_offset = self.client.scroll(
                collection_name="users",
                limit=limit,
            )
            return [dict(point.payload) for point in points]
        except Exception as e:
            logger.error("get_all_users_error", error=str(e))
            return []

    # ==================== QUERY HISTORY ====================

    def store_query(
        self,
        user_id: str,
        query: str,
        correlation_id: str,
        dau: Optional[int],
        parsed_context: Optional[Dict[str, Any]],
        recommendations: Optional[Dict[str, Any]],
        status: str,
        error_message: Optional[str],
        tokens_used: int,
        cost_usd: float,
    ) -> str:
        """Store a user query with semantic embedding for similarity search.

        Also increments the user's ``total_queries`` and ``total_cost_usd``.

        Returns:
            Query ID
        """
        query_id = str(uuid.uuid4())

        # Real embedding when the model is available, otherwise a zero
        # placeholder (such queries cannot be found by similarity search).
        if self.embedding_model is not None:
            query_embedding = self.embedding_model.encode(query).tolist()
        else:
            query_embedding = self._zero_vector()
            logger.warning("store_query_without_embedding", user_id=user_id)

        query_data = {
            "query_id": query_id,
            "user_id": user_id,
            "query": query,
            "correlation_id": correlation_id,
            "dau": dau,
            "parsed_context": parsed_context,
            "recommendations": recommendations,
            "status": status,
            "error_message": error_message,
            "tokens_used": tokens_used,
            "cost_usd": cost_usd,
            "created_at": time.time(),
        }
        self.client.upsert(
            collection_name="user_queries",
            points=[
                PointStruct(
                    id=self._point_id(query_id),
                    vector=query_embedding,
                    payload=query_data,
                )
            ],
        )

        # Update user stats (read-modify-write; not atomic).
        user = self.get_user_by_id(user_id)
        if user:
            self.update_user(user_id, {
                "total_queries": user.get("total_queries", 0) + 1,
                "total_cost_usd": user.get("total_cost_usd", 0.0) + cost_usd,
            })

        logger.info("query_stored", user_id=user_id, query_id=query_id, cost=cost_usd)
        return query_id

    def search_similar_queries(
        self,
        user_id: str,
        query: str,
        limit: int = 5,
    ) -> List[Dict[str, Any]]:
        """Search for similar queries by the same user using semantic search.

        Args:
            user_id: User ID to filter queries
            query: Query text to find similar queries for
            limit: Maximum number of similar queries to return

        Returns:
            List of similar queries with similarity scores; empty if the
            embedding model is not loaded or the search fails.
        """
        if self.embedding_model is None:
            logger.warning(
                "semantic_search_disabled",
                user_id=user_id,
                reason="Embedding model not loaded"
            )
            return []

        try:
            query_embedding = self.embedding_model.encode(query).tolist()
            results = self.client.search(
                collection_name="user_queries",
                query_vector=query_embedding,
                query_filter=self._match("user_id", user_id),
                limit=limit,
            )
            similar_queries = []
            for result in results:
                query_data = dict(result.payload)
                query_data["similarity_score"] = result.score
                similar_queries.append(query_data)

            logger.info(
                "similar_queries_found",
                user_id=user_id,
                count=len(similar_queries)
            )
            return similar_queries
        except Exception as e:
            logger.error("search_similar_queries_error", user_id=user_id, error=str(e))
            return []

    def get_user_queries(
        self,
        user_id: str,
        limit: int = 50,
    ) -> List[Dict[str, Any]]:
        """Get the ``limit`` most recent queries by a user, newest first.

        All of the user's queries are scanned before sorting, because scroll
        order follows point IDs, not ``created_at``.
        """
        try:
            queries = self._scroll_all_payloads("user_queries", self._match("user_id", user_id))
            queries.sort(key=lambda x: x.get("created_at", 0), reverse=True)
            return queries[:limit]
        except Exception as e:
            logger.error("get_user_queries_error", user_id=user_id, error=str(e))
            return []

    def get_all_queries(self, limit: int = 200) -> List[Dict[str, Any]]:
        """Get up to ``limit`` queries, sorted newest first (for admin dashboard).

        NOTE(review): only one scroll page (ordered by point ID) is fetched
        before sorting, so this is not guaranteed to be the newest ``limit``.
        """
        try:
            points, _next_offset = self.client.scroll(
                collection_name="user_queries",
                limit=limit,
            )
            queries = [dict(point.payload) for point in points]
            queries.sort(key=lambda x: x.get("created_at", 0), reverse=True)
            return queries
        except Exception as e:
            logger.error("get_all_queries_error", error=str(e))
            return []

    # ==================== FEEDBACK ====================

    def store_feedback(
        self,
        user_id: str,
        query_id: Optional[str],
        rating: Optional[int],
        comment: Optional[str],
        feedback_type: Optional[str],
    ) -> str:
        """Store user feedback and return its feedback ID."""
        feedback_id = str(uuid.uuid4())
        feedback_data = {
            "feedback_id": feedback_id,
            "user_id": user_id,
            "query_id": query_id,
            "rating": rating,
            "comment": comment,
            "feedback_type": feedback_type,
            "created_at": time.time(),
        }
        # Payload-only record: the vector is a zero placeholder.
        self.client.upsert(
            collection_name="user_feedback",
            points=[
                PointStruct(
                    id=self._point_id(feedback_id),
                    vector=self._zero_vector(),
                    payload=feedback_data,
                )
            ],
        )
        logger.info("feedback_stored", user_id=user_id, feedback_id=feedback_id, rating=rating)
        return feedback_id

    def get_user_feedback(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
        """Get the ``limit`` most recent feedback entries by a user, newest first."""
        try:
            feedback = self._scroll_all_payloads("user_feedback", self._match("user_id", user_id))
            feedback.sort(key=lambda x: x.get("created_at", 0), reverse=True)
            return feedback[:limit]
        except Exception as e:
            logger.error("get_user_feedback_error", user_id=user_id, error=str(e))
            return []

    def get_all_feedback(self, limit: int = 200) -> List[Dict[str, Any]]:
        """Get up to ``limit`` feedback entries, newest first (for admin dashboard).

        NOTE(review): only one scroll page (ordered by point ID) is fetched
        before sorting, so this is not guaranteed to be the newest ``limit``.
        """
        try:
            points, _next_offset = self.client.scroll(
                collection_name="user_feedback",
                limit=limit,
            )
            feedback = [dict(point.payload) for point in points]
            feedback.sort(key=lambda x: x.get("created_at", 0), reverse=True)
            return feedback
        except Exception as e:
            logger.error("get_all_feedback_error", error=str(e))
            return []
# Process-wide singleton, built lazily by get_user_memory_store().
_user_memory_store: Optional[UserMemoryStore] = None


def get_user_memory_store(use_local: bool = False) -> UserMemoryStore:
    """Return the shared UserMemoryStore, constructing it on first use.

    Args:
        use_local: Forwarded to the constructor on the first call only; later
            calls return the existing instance and ignore this flag.

    Returns:
        The process-wide store instance.
    """
    global _user_memory_store
    if _user_memory_store is not None:
        return _user_memory_store
    store = UserMemoryStore(use_local=use_local)
    _user_memory_store = store
    return store