# tyndale-ai-service/app/embeddings/retriever.py
"""FAISS-based retrieval with adaptive selection."""
import json
import logging
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Annotated
import faiss
import numpy as np
from fastapi import Depends
from app.config import settings
from app.embeddings.client import EmbeddingClient, get_embedding_client
logger = logging.getLogger(__name__)
@dataclass
class RetrievedChunk:
    """A chunk retrieved from FAISS search."""
    chunk_id: str  # unique identifier of the chunk in the metadata store
    content: str  # the chunk text sent to the LLM as context
    chunk_type: str  # category of the chunk (taken verbatim from metadata)
    artifact_file: str  # artifact the chunk was extracted from (used for diversity capping)
    source_file: str  # original source file behind the artifact
    tags: dict  # free-form tag mapping from metadata; defaults to {} when absent
    score: float  # similarity score from FAISS (inner product on L2-normalized vectors)
class Retriever:
    """FAISS-based retriever with adaptive selection logic.

    Holds a FAISS inner-product index and a parallel list of chunk
    metadata dicts. Raw search hits pass through two filters:

    - adaptive selection: the top-2 hits are always kept, later hits
      must meet the similarity threshold (`_adaptive_select`);
    - diversity: at most ``max_per_artifact`` chunks from the same
      artifact file (`_apply_diversity_filter`).
    """

    def __init__(self, embedding_client: EmbeddingClient):
        """Initialize the retriever.

        Args:
            embedding_client: Client for generating query embeddings
        """
        self.embedding_client = embedding_client
        self.embeddings_path = Path(settings.embeddings_path)
        # Maximum number of chunks search() may return.
        self.top_k = settings.rag_top_k
        # Minimum similarity for hits beyond the guaranteed top 2.
        self.threshold = settings.rag_similarity_threshold
        self._index: faiss.IndexFlatIP | None = None
        self._metadata: list[dict] | None = None
        self._loaded = False

    def load_index(self) -> bool:
        """Load FAISS index and metadata from disk.

        Returns:
            True if successfully loaded, False otherwise
        """
        index_path = self.embeddings_path / "faiss_index.bin"
        metadata_path = self.embeddings_path / "metadata.json"
        if not index_path.exists() or not metadata_path.exists():
            logger.warning("FAISS index or metadata not found. Run /index first.")
            self._loaded = False
            return False
        try:
            self._index = faiss.read_index(str(index_path))
            with open(metadata_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            self._metadata = data.get("chunks", [])
            self._loaded = True
            # Lazy %-args keep string formatting off the happy path.
            logger.info("Loaded FAISS index with %d vectors", self._index.ntotal)
            return True
        except Exception as e:
            # Broad catch is deliberate: a corrupt/unreadable index must
            # not crash the service; callers see False and can re-index.
            logger.error("Failed to load FAISS index: %s", e)
            self._loaded = False
            return False

    @property
    def is_loaded(self) -> bool:
        """Check if index is loaded."""
        return self._loaded and self._index is not None

    @property
    def index_size(self) -> int:
        """Get number of vectors in index."""
        if self._index is None:
            return 0
        return self._index.ntotal

    def _adaptive_select(
        self,
        indices: np.ndarray,
        scores: np.ndarray,
    ) -> list[tuple[int, float]]:
        """Apply adaptive selection logic.

        - Always include top 2 chunks (regardless of score)
        - For later chunks: apply threshold and cap at self.top_k

        Note: because the top 2 are unconditional, the result can
        exceed top_k only when top_k < 2.

        Args:
            indices: FAISS result indices
            scores: FAISS result scores

        Returns:
            List of (index, score) tuples for selected chunks
        """
        selected: list[tuple[int, float]] = []
        for rank, (idx, score) in enumerate(zip(indices, scores)):
            if idx == -1:  # FAISS pads missing results with -1
                continue
            if rank < 2:
                # Guaranteed inclusion of the two best hits.
                selected.append((int(idx), float(score)))
            elif score >= self.threshold and len(selected) < self.top_k:
                selected.append((int(idx), float(score)))
        return selected

    def _apply_diversity_filter(
        self,
        candidates: list[tuple[int, float]],
        max_per_artifact: int = 2,
    ) -> list[tuple[int, float]]:
        """Limit chunks per artifact for diversity.

        Args:
            candidates: List of (index, score) tuples
            max_per_artifact: Maximum chunks from same artifact

        Returns:
            Filtered list of (index, score) tuples, original order kept
        """
        artifact_counts: dict[str, int] = {}
        filtered: list[tuple[int, float]] = []
        for idx, score in candidates:
            artifact = self._metadata[idx]["artifact_file"]
            if artifact_counts.get(artifact, 0) < max_per_artifact:
                filtered.append((idx, score))
                artifact_counts[artifact] = artifact_counts.get(artifact, 0) + 1
        return filtered

    async def search(self, query: str) -> list[RetrievedChunk]:
        """Search for relevant chunks.

        Args:
            query: User's question

        Returns:
            List of retrieved chunks with relevance scores; empty when
            the index cannot be loaded or contains no vectors.
        """
        if not self.is_loaded:
            if not self.load_index():
                return []
        if self._index.ntotal == 0:
            # Nothing indexed yet; avoid a k=0 FAISS search.
            return []
        # Generate query embedding
        query_embedding = await self.embedding_client.embed(query)
        query_vector = np.array([query_embedding], dtype=np.float32)
        # Normalize so inner-product search behaves as cosine similarity
        faiss.normalize_L2(query_vector)
        # Over-fetch candidates so the threshold and diversity filters
        # still have enough to pick from. (A previously hard-coded 8
        # silently capped results whenever top_k > 8.)
        k_search = min(max(self.top_k * 2, 8), self._index.ntotal)
        scores, indices = self._index.search(query_vector, k_search)
        # Apply adaptive selection, then the per-artifact diversity cap.
        selected = self._adaptive_select(indices[0], scores[0])
        filtered = self._apply_diversity_filter(selected)
        # Materialize result dataclasses from the metadata records.
        results = []
        for idx, score in filtered:
            chunk_data = self._metadata[idx]
            results.append(RetrievedChunk(
                chunk_id=chunk_data["chunk_id"],
                content=chunk_data["content"],
                chunk_type=chunk_data["chunk_type"],
                artifact_file=chunk_data["artifact_file"],
                source_file=chunk_data["source_file"],
                tags=chunk_data.get("tags", {}),
                score=score,
            ))
        logger.debug("Retrieved %d chunks for query", len(results))
        return results
# Module-level singleton, created on first use.
_retriever: Retriever | None = None


def get_retriever() -> Retriever:
    """Get singleton retriever instance (lazily initialized)."""
    global _retriever
    if _retriever is not None:
        return _retriever
    _retriever = Retriever(embedding_client=get_embedding_client())
    # Eagerly try to load the index so the first query does not pay
    # the cost; failure is tolerated (search() retries loading).
    _retriever.load_index()
    return _retriever
def reset_retriever() -> None:
    """Reset the singleton retriever (for reloading after re-indexing)."""
    global _retriever
    # The next get_retriever() call builds a fresh instance and reloads
    # the index from disk.
    _retriever = None


# FastAPI dependency alias: injects the shared Retriever into endpoints.
RetrieverDependency = Annotated[Retriever, Depends(get_retriever)]