"""YAML artifact indexer for building FAISS index.""" import json import logging from dataclasses import dataclass, asdict from functools import lru_cache from pathlib import Path from typing import Annotated import faiss import numpy as np import yaml from fastapi import Depends from app.config import settings from app.embeddings.client import EmbeddingClient, get_embedding_client, EmbeddingError logger = logging.getLogger(__name__) @dataclass class Chunk: """Represents a text chunk from a YAML artifact.""" chunk_id: str content: str chunk_type: str # factual_summary, interpretive_summary, method, invariants artifact_file: str source_file: str tags: dict @dataclass class IndexResult: """Result of indexing operation.""" chunks_indexed: int artifacts_processed: int status: str class ArtifactIndexer: """Parses YAML artifacts and builds FAISS index.""" def __init__(self, embedding_client: EmbeddingClient): """Initialize the indexer. Args: embedding_client: Client for generating embeddings """ self.embedding_client = embedding_client self.artifacts_path = Path(settings.artifacts_path) self.embeddings_path = Path(settings.embeddings_path) self.dimensions = 1536 def _parse_yaml_to_chunks(self, yaml_path: Path) -> list[Chunk]: """Parse a YAML artifact file into chunks. Args: yaml_path: Path to the YAML file Returns: List of chunks extracted from the file """ chunks = [] artifact_file = str(yaml_path.relative_to(self.artifacts_path)) try: with open(yaml_path, "r", encoding="utf-8") as f: data = yaml.safe_load(f) except Exception as e: logger.warning(f"Failed to parse {yaml_path}: {e}") return chunks if not data: return chunks source_file = data.get("source_file", "unknown") tags = data.get("tags", {}) # Chunk 1: Factual summary if factual := data.get("factual_summary"): chunks.append(Chunk( chunk_id=f"{artifact_file}::factual_summary", content=f"[{source_file}] Factual summary: {factual.strip()}", chunk_type="factual_summary", artifact_file=artifact_file, source_file=source_file, tags=tags, )) # Chunk 2: Interpretive summary if interpretive := data.get("interpretive_summary"): chunks.append(Chunk( chunk_id=f"{artifact_file}::interpretive_summary", content=f"[{source_file}] Interpretive summary: {interpretive.strip()}", chunk_type="interpretive_summary", artifact_file=artifact_file, source_file=source_file, tags=tags, )) # Chunk per method if methods := data.get("methods"): for method_sig, method_data in methods.items(): description = method_data.get("description", "") if isinstance(method_data, dict) else method_data if description: chunks.append(Chunk( chunk_id=f"{artifact_file}::method::{method_sig}", content=f"[{source_file}] Method {method_sig}: {description.strip()}", chunk_type="method", artifact_file=artifact_file, source_file=source_file, tags=tags, )) # Chunk for invariants (combined) if invariants := data.get("invariants"): invariants_text = " ".join(f"- {inv}" for inv in invariants) chunks.append(Chunk( chunk_id=f"{artifact_file}::invariants", content=f"[{source_file}] Invariants: {invariants_text}", chunk_type="invariants", artifact_file=artifact_file, source_file=source_file, tags=tags, )) return chunks def _collect_all_chunks(self) -> list[Chunk]: """Collect chunks from all YAML artifacts. 
Returns: List of all chunks from all artifacts """ all_chunks = [] for yaml_path in self.artifacts_path.rglob("*.yaml"): chunks = self._parse_yaml_to_chunks(yaml_path) all_chunks.extend(chunks) logger.debug(f"Parsed {len(chunks)} chunks from {yaml_path}") return all_chunks async def build_index(self) -> IndexResult: """Build FAISS index from all YAML artifacts. Returns: IndexResult with statistics Raises: EmbeddingError: If embedding generation fails """ # Collect all chunks chunks = self._collect_all_chunks() if not chunks: logger.warning("No chunks found in artifacts") return IndexResult( chunks_indexed=0, artifacts_processed=0, status="no_artifacts", ) logger.info(f"Generating embeddings for {len(chunks)} chunks...") # Generate embeddings in batches batch_size = 100 all_embeddings = [] for i in range(0, len(chunks), batch_size): batch = chunks[i:i + batch_size] texts = [chunk.content for chunk in batch] embeddings = await self.embedding_client.embed_batch(texts) all_embeddings.extend(embeddings) logger.debug(f"Embedded batch {i // batch_size + 1}") # Build FAISS index (IndexFlatIP for inner product / cosine similarity on normalized vectors) embeddings_array = np.array(all_embeddings, dtype=np.float32) # Normalize for cosine similarity faiss.normalize_L2(embeddings_array) index = faiss.IndexFlatIP(self.dimensions) index.add(embeddings_array) # Create embeddings directory if needed self.embeddings_path.mkdir(parents=True, exist_ok=True) # Save FAISS index faiss.write_index(index, str(self.embeddings_path / "faiss_index.bin")) # Save metadata metadata = { "chunks": [asdict(chunk) for chunk in chunks], } with open(self.embeddings_path / "metadata.json", "w", encoding="utf-8") as f: json.dump(metadata, f, indent=2) # Save index info artifact_files = set(chunk.artifact_file for chunk in chunks) index_info = { "total_chunks": len(chunks), "total_artifacts": len(artifact_files), "dimensions": self.dimensions, "index_type": "IndexFlatIP", } with open(self.embeddings_path / "index_info.json", "w", encoding="utf-8") as f: json.dump(index_info, f, indent=2) logger.info(f"Indexed {len(chunks)} chunks from {len(artifact_files)} artifacts") return IndexResult( chunks_indexed=len(chunks), artifacts_processed=len(artifact_files), status="completed", ) @lru_cache() def get_indexer() -> ArtifactIndexer: """Get cached indexer instance.""" return ArtifactIndexer(embedding_client=get_embedding_client()) IndexerDependency = Annotated[ArtifactIndexer, Depends(get_indexer)]
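
# A minimal sketch of the artifact layout that `_parse_yaml_to_chunks` expects.
# Only the keys (source_file, tags, factual_summary, interpretive_summary,
# methods, invariants) come from the parser above; the file paths and values
# below are purely illustrative assumptions, not real artifacts:
#
#     source_file: src/billing/invoice.py
#     tags:
#       domain: billing
#     factual_summary: >
#       Builds invoices from order line items.
#     interpretive_summary: >
#       Central place where pricing rules are applied.
#     methods:
#       "build_invoice(order)":
#         description: Assembles an invoice for a completed order.
#     invariants:
#       - Totals are computed in the order's currency.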
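
# Example wiring (a sketch, not part of this module): `get_indexer` and
# `IndexerDependency` are defined above, but the FastAPI app, route path, and
# import path shown here are assumptions for illustration only.
#
#     from fastapi import FastAPI
#     from app.indexing.indexer import IndexerDependency  # import path assumed
#
#     app = FastAPI()
#
#     @app.post("/index/build")
#     async def rebuild_index(indexer: IndexerDependency):
#         return await indexer.build_index()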