# tyndale-ai-service/app/embeddings/indexer.py
#
# 228 lines
# 7.3 KiB
# Python

"""YAML artifact indexer for building FAISS index."""
import json
import logging
from dataclasses import dataclass, asdict
from functools import lru_cache
from pathlib import Path
from typing import Annotated
import faiss
import numpy as np
import yaml
from fastapi import Depends
from app.config import settings
from app.embeddings.client import EmbeddingClient, get_embedding_client, EmbeddingError
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
@dataclass
class Chunk:
    """One embeddable text fragment cut from a YAML artifact.

    Chunks are converted to plain dicts with ``dataclasses.asdict`` when the
    index metadata is written, so these field names also appear in
    ``metadata.json``.
    """

    # Identifier of the form "<artifact_file>::<section>", e.g. "x.yaml::invariants".
    chunk_id: str
    # Text that is sent to the embedding model.
    content: str
    # Section kind: factual_summary, interpretive_summary, method or invariants.
    chunk_type: str
    # Artifact YAML path, relative to the artifacts directory.
    artifact_file: str
    # The artifact's "source_file" entry ("unknown" when missing).
    source_file: str
    # The artifact's "tags" mapping, attached to every chunk from that file.
    tags: dict
@dataclass
class IndexResult:
    """Summary of one ``ArtifactIndexer.build_index`` run."""

    # Number of chunks embedded and written to the FAISS index.
    chunks_indexed: int
    # Number of distinct artifact files that contributed at least one chunk.
    artifacts_processed: int
    # Outcome label, e.g. "completed" or "no_artifacts".
    status: str
class ArtifactIndexer:
    """Parses YAML artifacts and builds a FAISS index over their text chunks.

    Every ``*.yaml`` file under ``settings.artifacts_path`` is split into
    chunks: factual summary, interpretive summary, one chunk per method, and
    one combined invariants chunk. Each chunk's text is embedded with the
    injected ``EmbeddingClient``. The L2-normalized vectors are stored in a
    FAISS ``IndexFlatIP``, where inner product on unit vectors equals cosine
    similarity.

    ``build_index`` writes ``faiss_index.bin``, ``metadata.json`` and
    ``index_info.json`` to ``settings.embeddings_path``. Row ``i`` of the
    index corresponds to ``metadata["chunks"][i]``.
    """

    # Number of chunk texts sent to the embedding client per request.
    batch_size: int = 100

    def __init__(self, embedding_client: EmbeddingClient):
        """Initialize the indexer.

        Args:
            embedding_client: Client for generating embeddings.
        """
        self.embedding_client = embedding_client
        self.artifacts_path = Path(settings.artifacts_path)
        self.embeddings_path = Path(settings.embeddings_path)
        # Expected embedding width. build_index builds the index with the
        # width the client actually returns, and logs a warning if it
        # differs from this value.
        self.dimensions = 1536

    @staticmethod
    def _clean_text(value: object) -> str:
        """Return ``value`` as stripped text.

        YAML can yield non-string scalars (numbers, booleans, dates) for
        fields that are normally prose. Calling ``str()`` first avoids an
        AttributeError on ``.strip()`` for those values.
        """
        return str(value).strip()

    def _parse_yaml_to_chunks(self, yaml_path: Path) -> list[Chunk]:
        """Parse a YAML artifact file into chunks.

        The file is skipped, rather than aborting the whole build, if it:
        cannot be read or parsed, is empty, or has a top level that is not a
        mapping. A warning is logged when the file is malformed.

        Args:
            yaml_path: Path to the YAML file; must live under artifacts_path.

        Returns:
            List of chunks extracted from the file (possibly empty).
        """
        artifact_file = str(yaml_path.relative_to(self.artifacts_path))
        try:
            with open(yaml_path, "r", encoding="utf-8") as f:
                # safe_load: artifact contents are treated as plain data only.
                data = yaml.safe_load(f)
        except (OSError, UnicodeDecodeError, yaml.YAMLError) as e:
            logger.warning("Failed to parse %s: %s", yaml_path, e)
            return []

        if not data:
            return []
        if not isinstance(data, dict):
            logger.warning(
                "Skipping %s: expected a YAML mapping at top level, got %s",
                yaml_path,
                type(data).__name__,
            )
            return []

        # `or` also covers keys that are present but null in the YAML.
        source_file = data.get("source_file") or "unknown"
        tags = data.get("tags") or {}
        chunks: list[Chunk] = []

        def add(id_suffix: str, chunk_type: str, label: str, body: str) -> None:
            # Skip sections whose text is empty once whitespace is stripped.
            if body:
                chunks.append(Chunk(
                    chunk_id=f"{artifact_file}::{id_suffix}",
                    content=f"[{source_file}] {label}: {body}",
                    chunk_type=chunk_type,
                    artifact_file=artifact_file,
                    source_file=source_file,
                    tags=tags,
                ))

        if factual := data.get("factual_summary"):
            add("factual_summary", "factual_summary", "Factual summary",
                self._clean_text(factual))

        if interpretive := data.get("interpretive_summary"):
            add("interpretive_summary", "interpretive_summary", "Interpretive summary",
                self._clean_text(interpretive))

        # One chunk per method. Each entry is either a mapping with a
        # "description" key or the description itself.
        methods = data.get("methods")
        if isinstance(methods, dict):
            for method_sig, method_data in methods.items():
                if isinstance(method_data, dict):
                    method_data = method_data.get("description")
                if method_data:
                    add(f"method::{method_sig}", "method", f"Method {method_sig}",
                        self._clean_text(method_data))
        elif methods:
            logger.warning("Ignoring non-mapping 'methods' in %s", yaml_path)

        # All invariants are combined into a single chunk. A bare string is
        # treated as one invariant; iterating it would emit one per character.
        if invariants := data.get("invariants"):
            items = invariants if isinstance(invariants, list) else [invariants]
            add("invariants", "invariants", "Invariants",
                " ".join(f"- {inv}" for inv in items))

        return chunks

    def _collect_all_chunks(self) -> list[Chunk]:
        """Collect chunks from all ``*.yaml`` artifacts under artifacts_path.

        Files are visited in sorted path order. This makes chunk order, and
        therefore FAISS row ids and metadata positions, reproducible across
        runs and filesystems.

        Returns:
            List of all chunks from all artifacts.
        """
        if not self.artifacts_path.is_dir():
            logger.warning("Artifacts path %s is not a directory", self.artifacts_path)
            return []
        all_chunks: list[Chunk] = []
        for yaml_path in sorted(self.artifacts_path.rglob("*.yaml")):
            if not yaml_path.is_file():
                continue
            chunks = self._parse_yaml_to_chunks(yaml_path)
            all_chunks.extend(chunks)
            logger.debug("Parsed %d chunks from %s", len(chunks), yaml_path)
        return all_chunks

    async def _embed_chunks(self, chunks: list[Chunk]) -> np.ndarray:
        """Embed chunk contents in batches and return them as a 2-D float32 array.

        Row ``i`` of the result is the embedding of ``chunks[i].content``.

        Raises:
            EmbeddingError: If the client fails, returns a different number of
                vectors than texts, or returns vectors of inconsistent width.
        """
        all_embeddings: list = []
        for batch_no, start in enumerate(range(0, len(chunks), self.batch_size), start=1):
            texts = [chunk.content for chunk in chunks[start:start + self.batch_size]]
            embeddings = list(await self.embedding_client.embed_batch(texts))
            # A short or long batch would silently misalign index rows with metadata.
            if len(embeddings) != len(texts):
                raise EmbeddingError(
                    f"Embedding batch {batch_no} returned {len(embeddings)} "
                    f"vectors for {len(texts)} texts"
                )
            all_embeddings.extend(embeddings)
            logger.debug("Embedded batch %d", batch_no)

        try:
            embeddings_array = np.array(all_embeddings, dtype=np.float32)
        except ValueError as e:  # ragged vectors
            raise EmbeddingError(f"Embeddings have inconsistent dimensions: {e}") from e
        if embeddings_array.ndim != 2:
            raise EmbeddingError(
                f"Expected a 2-D embeddings array, got shape {embeddings_array.shape}"
            )
        return embeddings_array

    def _save(self, index, chunks: list[Chunk], artifact_count: int, dimensions: int) -> None:
        """Write the FAISS index, chunk metadata and index info to embeddings_path.

        All three files are first written to ``*.tmp`` siblings and only then
        moved into place with ``Path.replace``. A failure while writing
        therefore leaves the previous files untouched, instead of a truncated
        or mismatched set.
        """
        self.embeddings_path.mkdir(parents=True, exist_ok=True)
        index_path = self.embeddings_path / "faiss_index.bin"
        metadata_path = self.embeddings_path / "metadata.json"
        info_path = self.embeddings_path / "index_info.json"

        def tmp(path: Path) -> Path:
            return path.with_name(path.name + ".tmp")

        faiss.write_index(index, str(tmp(index_path)))
        with open(tmp(metadata_path), "w", encoding="utf-8") as f:
            json.dump({"chunks": [asdict(chunk) for chunk in chunks]}, f, indent=2)
        with open(tmp(info_path), "w", encoding="utf-8") as f:
            json.dump(
                {
                    "total_chunks": len(chunks),
                    "total_artifacts": artifact_count,
                    "dimensions": dimensions,
                    "index_type": "IndexFlatIP",
                },
                f,
                indent=2,
            )
        for final_path in (index_path, metadata_path, info_path):
            tmp(final_path).replace(final_path)

    async def build_index(self) -> IndexResult:
        """Build the FAISS index from all YAML artifacts.

        Returns:
            IndexResult with statistics. Its status is "no_artifacts" (and
            nothing is written) when no chunks were found, otherwise
            "completed".

        Raises:
            EmbeddingError: If embedding generation fails or returns
                inconsistent results.
            OSError: If the output files cannot be written.
        """
        chunks = self._collect_all_chunks()
        if not chunks:
            logger.warning("No chunks found in artifacts")
            return IndexResult(
                chunks_indexed=0,
                artifacts_processed=0,
                status="no_artifacts",
            )

        logger.info("Generating embeddings for %d chunks...", len(chunks))
        embeddings_array = await self._embed_chunks(chunks)

        # Build with the width the model actually produced; a hard-coded
        # width makes index.add fail when the model changes.
        dimensions = int(embeddings_array.shape[1])
        if dimensions != self.dimensions:
            logger.warning(
                "Embedding dimension %d differs from expected %d; building index with %d",
                dimensions, self.dimensions, dimensions,
            )

        # Normalize in place so inner product (IndexFlatIP) equals cosine similarity.
        faiss.normalize_L2(embeddings_array)
        index = faiss.IndexFlatIP(dimensions)
        index.add(embeddings_array)

        artifact_count = len({chunk.artifact_file for chunk in chunks})
        self._save(index, chunks, artifact_count, dimensions)

        logger.info("Indexed %d chunks from %d artifacts", len(chunks), artifact_count)
        return IndexResult(
            chunks_indexed=len(chunks),
            artifacts_processed=artifact_count,
            status="completed",
        )
@lru_cache
def get_indexer() -> ArtifactIndexer:
    """Return the shared ArtifactIndexer, creating it on first use.

    The function takes no arguments, so ``lru_cache`` holds a single entry.
    Every call after the first returns that same instance.
    """
    client = get_embedding_client()
    return ArtifactIndexer(embedding_client=client)


# FastAPI dependency alias: annotate an endpoint parameter with this type to
# have the shared indexer injected through Depends(get_indexer).
IndexerDependency = Annotated[ArtifactIndexer, Depends(get_indexer)]