222 lines
6.5 KiB
Python
222 lines
6.5 KiB
Python
"""FAISS-based retrieval with adaptive selection."""
|
|
|
|
import json
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from functools import lru_cache
|
|
from pathlib import Path
|
|
from typing import Annotated
|
|
|
|
import faiss
|
|
import numpy as np
|
|
from fastapi import Depends
|
|
|
|
from app.config import settings
|
|
from app.embeddings.client import EmbeddingClient, get_embedding_client
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class RetrievedChunk:
    """A chunk retrieved from FAISS search."""

    # Unique identifier of the chunk, taken from the metadata file.
    chunk_id: str
    # Raw text content of the chunk.
    content: str
    # Chunk category label (value originates from the indexing pipeline).
    chunk_type: str
    # Artifact file the chunk belongs to; used by the diversity filter
    # to cap how many chunks come from the same artifact.
    artifact_file: str
    # Original source file the artifact was derived from.
    source_file: str
    # Free-form tag mapping; defaults to {} when absent in metadata.
    tags: dict
    # Similarity score returned by FAISS (inner product on L2-normalized
    # vectors, i.e. cosine similarity).
    score: float
|
|
|
|
|
|
class Retriever:
    """FAISS-based retriever with adaptive selection logic.

    Loads a FAISS inner-product index plus a parallel JSON metadata list
    from ``settings.embeddings_path`` and answers queries with a
    threshold-gated, diversity-filtered top-k selection.
    """

    def __init__(self, embedding_client: EmbeddingClient):
        """Initialize the retriever.

        Args:
            embedding_client: Client for generating query embeddings
        """
        self.embedding_client = embedding_client
        self.embeddings_path = Path(settings.embeddings_path)
        self.top_k = settings.rag_top_k
        self.threshold = settings.rag_similarity_threshold

        # Populated lazily by load_index(); None until a successful load.
        self._index: faiss.IndexFlatIP | None = None
        self._metadata: list[dict] | None = None
        self._loaded = False

    def load_index(self) -> bool:
        """Load FAISS index and metadata from disk.

        Returns:
            True if successfully loaded, False otherwise
        """
        index_path = self.embeddings_path / "faiss_index.bin"
        metadata_path = self.embeddings_path / "metadata.json"

        if not index_path.exists() or not metadata_path.exists():
            logger.warning("FAISS index or metadata not found. Run /index first.")
            self._loaded = False
            return False

        try:
            self._index = faiss.read_index(str(index_path))

            with open(metadata_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            self._metadata = data.get("chunks", [])

            # FIX: warn on index/metadata drift (e.g. a partial re-index).
            # Search guards against out-of-range indices, but a mismatch
            # almost always means one of the two files is stale.
            if len(self._metadata) != self._index.ntotal:
                logger.warning(
                    "Metadata chunk count (%d) does not match index size (%d); "
                    "results may be incomplete. Consider re-running /index.",
                    len(self._metadata),
                    self._index.ntotal,
                )

            self._loaded = True
            logger.info("Loaded FAISS index with %d vectors", self._index.ntotal)
            return True

        except Exception as e:
            logger.error("Failed to load FAISS index: %s", e)
            self._loaded = False
            return False

    @property
    def is_loaded(self) -> bool:
        """Check if both the index and its metadata are loaded."""
        # FIX: also require metadata; previously an index loaded without
        # metadata would pass this check and crash later during search().
        return self._loaded and self._index is not None and self._metadata is not None

    @property
    def index_size(self) -> int:
        """Get number of vectors in index (0 when no index is loaded)."""
        if self._index is None:
            return 0
        return self._index.ntotal

    def _adaptive_select(
        self,
        indices: np.ndarray,
        scores: np.ndarray,
    ) -> list[tuple[int, float]]:
        """Apply adaptive selection logic.

        - Always include the top 2 chunks (regardless of score)
        - Apply the similarity threshold to every remaining candidate
        - Limit the total to self.top_k chunks
          (docstring fixed: the old text claimed "chunks 3-5", but the cap
          is self.top_k, not 5)

        Args:
            indices: FAISS result indices
            scores: FAISS result scores

        Returns:
            List of (index, score) tuples for selected chunks
        """
        selected = []

        for i, (idx, score) in enumerate(zip(indices, scores)):
            if idx == -1:  # FAISS returns -1 for no match
                continue

            # Always take top 2
            if i < 2:
                selected.append((int(idx), float(score)))
            # Apply threshold for remaining
            elif score >= self.threshold and len(selected) < self.top_k:
                selected.append((int(idx), float(score)))

        return selected

    def _apply_diversity_filter(
        self,
        candidates: list[tuple[int, float]],
        max_per_artifact: int = 2,
    ) -> list[tuple[int, float]]:
        """Limit chunks per artifact for diversity.

        Args:
            candidates: List of (index, score) tuples
            max_per_artifact: Maximum chunks from same artifact

        Returns:
            Filtered list of (index, score) tuples
        """
        artifact_counts: dict[str, int] = {}
        filtered = []

        for idx, score in candidates:
            # FIX: guard against stale metadata (fewer entries than index
            # vectors); skip out-of-range hits instead of raising IndexError.
            if idx >= len(self._metadata):
                logger.warning("FAISS index %d has no metadata entry; skipping", idx)
                continue

            chunk = self._metadata[idx]
            artifact = chunk["artifact_file"]

            if artifact_counts.get(artifact, 0) < max_per_artifact:
                filtered.append((idx, score))
                artifact_counts[artifact] = artifact_counts.get(artifact, 0) + 1

        return filtered

    async def search(self, query: str) -> list[RetrievedChunk]:
        """Search for relevant chunks.

        Args:
            query: User's question

        Returns:
            List of retrieved chunks with relevance scores
            (empty when no index is available)
        """
        if not self.is_loaded:
            if not self.load_index():
                return []

        # Generate query embedding
        query_embedding = await self.embedding_client.embed(query)
        query_vector = np.array([query_embedding], dtype=np.float32)

        # Normalize for cosine similarity (index stores inner products)
        faiss.normalize_L2(query_vector)

        # FIX: the candidate pool was hardcoded to 8, which silently capped
        # retrieval below top_k whenever rag_top_k > 8. Fetch enough
        # candidates to honor the configured top_k (still >= 8 to leave
        # headroom for the threshold and diversity filters).
        k_search = min(max(self.top_k, 8), self._index.ntotal)
        scores, indices = self._index.search(query_vector, k_search)

        # Apply adaptive selection
        selected = self._adaptive_select(indices[0], scores[0])

        # Apply diversity filter
        filtered = self._apply_diversity_filter(selected)

        # Build result chunks
        results = []
        for idx, score in filtered:
            chunk_data = self._metadata[idx]
            results.append(RetrievedChunk(
                chunk_id=chunk_data["chunk_id"],
                content=chunk_data["content"],
                chunk_type=chunk_data["chunk_type"],
                artifact_file=chunk_data["artifact_file"],
                source_file=chunk_data["source_file"],
                tags=chunk_data.get("tags", {}),
                score=score,
            ))

        logger.debug("Retrieved %d chunks for query", len(results))
        return results
|
|
|
|
|
|
# Module-level singleton; rebuilt on demand after reset_retriever().
_retriever: Retriever | None = None


def get_retriever() -> Retriever:
    """Get singleton retriever instance (lazily initialized)."""
    global _retriever
    if _retriever is not None:
        return _retriever

    _retriever = Retriever(embedding_client=get_embedding_client())
    # Eagerly try to read the on-disk index so the first query does not
    # pay the loading cost (a missing index is tolerated here).
    _retriever.load_index()
    return _retriever
|
|
|
|
|
|
def reset_retriever() -> None:
    """Reset the singleton retriever (for reloading after re-indexing)."""
    # Drop the cached instance; the next get_retriever() call constructs a
    # fresh Retriever and re-reads the FAISS index from disk.
    global _retriever
    _retriever = None
|
|
|
|
|
|
# FastAPI dependency alias: annotate a route parameter with this type to
# have the singleton Retriever injected via Depends(get_retriever).
RetrieverDependency = Annotated[Retriever, Depends(get_retriever)]
|