"""FAISS-based retrieval with adaptive selection.""" import json import logging from dataclasses import dataclass from functools import lru_cache from pathlib import Path from typing import Annotated import faiss import numpy as np from fastapi import Depends from app.config import settings from app.embeddings.client import EmbeddingClient, get_embedding_client logger = logging.getLogger(__name__) @dataclass class RetrievedChunk: """A chunk retrieved from FAISS search.""" chunk_id: str content: str chunk_type: str artifact_file: str source_file: str tags: dict score: float class Retriever: """FAISS-based retriever with adaptive selection logic.""" def __init__(self, embedding_client: EmbeddingClient): """Initialize the retriever. Args: embedding_client: Client for generating query embeddings """ self.embedding_client = embedding_client self.embeddings_path = Path(settings.embeddings_path) self.top_k = settings.rag_top_k self.threshold = settings.rag_similarity_threshold self._index: faiss.IndexFlatIP | None = None self._metadata: list[dict] | None = None self._loaded = False def load_index(self) -> bool: """Load FAISS index and metadata from disk. Returns: True if successfully loaded, False otherwise """ index_path = self.embeddings_path / "faiss_index.bin" metadata_path = self.embeddings_path / "metadata.json" if not index_path.exists() or not metadata_path.exists(): logger.warning("FAISS index or metadata not found. Run /index first.") self._loaded = False return False try: self._index = faiss.read_index(str(index_path)) with open(metadata_path, "r", encoding="utf-8") as f: data = json.load(f) self._metadata = data.get("chunks", []) self._loaded = True logger.info(f"Loaded FAISS index with {self._index.ntotal} vectors") return True except Exception as e: logger.error(f"Failed to load FAISS index: {e}") self._loaded = False return False @property def is_loaded(self) -> bool: """Check if index is loaded.""" return self._loaded and self._index is not None @property def index_size(self) -> int: """Get number of vectors in index.""" if self._index is None: return 0 return self._index.ntotal def _adaptive_select( self, indices: np.ndarray, scores: np.ndarray, ) -> list[tuple[int, float]]: """Apply adaptive selection logic. - Always include top 2 chunks (regardless of score) - For chunks 3-5: apply threshold - Limit to self.top_k chunks total Args: indices: FAISS result indices scores: FAISS result scores Returns: List of (index, score) tuples for selected chunks """ selected = [] for i, (idx, score) in enumerate(zip(indices, scores)): if idx == -1: # FAISS returns -1 for no match continue # Always take top 2 if i < 2: selected.append((int(idx), float(score))) # Apply threshold for remaining elif score >= self.threshold and len(selected) < self.top_k: selected.append((int(idx), float(score))) return selected def _apply_diversity_filter( self, candidates: list[tuple[int, float]], max_per_artifact: int = 2, ) -> list[tuple[int, float]]: """Limit chunks per artifact for diversity. Args: candidates: List of (index, score) tuples max_per_artifact: Maximum chunks from same artifact Returns: Filtered list of (index, score) tuples """ artifact_counts: dict[str, int] = {} filtered = [] for idx, score in candidates: chunk = self._metadata[idx] artifact = chunk["artifact_file"] if artifact_counts.get(artifact, 0) < max_per_artifact: filtered.append((idx, score)) artifact_counts[artifact] = artifact_counts.get(artifact, 0) + 1 return filtered async def search(self, query: str) -> list[RetrievedChunk]: """Search for relevant chunks. Args: query: User's question Returns: List of retrieved chunks with relevance scores """ if not self.is_loaded: if not self.load_index(): return [] # Generate query embedding query_embedding = await self.embedding_client.embed(query) query_vector = np.array([query_embedding], dtype=np.float32) # Normalize for cosine similarity faiss.normalize_L2(query_vector) # Search FAISS (get more candidates than needed for filtering) k_search = min(8, self._index.ntotal) scores, indices = self._index.search(query_vector, k_search) # Apply adaptive selection selected = self._adaptive_select(indices[0], scores[0]) # Apply diversity filter filtered = self._apply_diversity_filter(selected) # Build result chunks results = [] for idx, score in filtered: chunk_data = self._metadata[idx] results.append(RetrievedChunk( chunk_id=chunk_data["chunk_id"], content=chunk_data["content"], chunk_type=chunk_data["chunk_type"], artifact_file=chunk_data["artifact_file"], source_file=chunk_data["source_file"], tags=chunk_data.get("tags", {}), score=score, )) logger.debug(f"Retrieved {len(results)} chunks for query") return results # Singleton retriever instance _retriever: Retriever | None = None def get_retriever() -> Retriever: """Get singleton retriever instance (lazily initialized).""" global _retriever if _retriever is None: _retriever = Retriever(embedding_client=get_embedding_client()) # Attempt to load index at startup _retriever.load_index() return _retriever def reset_retriever() -> None: """Reset the singleton retriever (for reloading after re-indexing).""" global _retriever _retriever = None RetrieverDependency = Annotated[Retriever, Depends(get_retriever)]