"""YAML artifact indexer for building FAISS index."""

import json
import logging
from dataclasses import dataclass, asdict
from functools import lru_cache
from pathlib import Path
from typing import Annotated

import faiss
import numpy as np
import yaml
from fastapi import Depends

from app.config import settings
from app.embeddings.client import EmbeddingClient, get_embedding_client, EmbeddingError

logger = logging.getLogger(__name__)


@dataclass
class Chunk:
    """Represents a text chunk from a YAML artifact."""

    chunk_id: str
    content: str
    chunk_type: str  # factual_summary, interpretive_summary, method, invariants
    artifact_file: str
    source_file: str
    tags: dict


@dataclass
class IndexResult:
    """Result of indexing operation."""

    chunks_indexed: int
    artifacts_processed: int
    status: str


class ArtifactIndexer:
    """Parses YAML artifacts and builds FAISS index."""

    def __init__(self, embedding_client: EmbeddingClient):
        """Initialize the indexer.

        Args:
            embedding_client: Client for generating embeddings
        """
        self.embedding_client = embedding_client
        self.artifacts_path = Path(settings.artifacts_path)
        self.embeddings_path = Path(settings.embeddings_path)
        self.dimensions = 1536

    def _parse_yaml_to_chunks(self, yaml_path: Path) -> list[Chunk]:
        """Parse a YAML artifact file into chunks.

        Args:
            yaml_path: Path to the YAML file

        Returns:
            List of chunks extracted from the file
        """
        chunks = []
        artifact_file = str(yaml_path.relative_to(self.artifacts_path))

        try:
            with open(yaml_path, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f)
        except Exception as e:
            logger.warning(f"Failed to parse {yaml_path}: {e}")
            return chunks

        if not data:
            return chunks

        source_file = data.get("source_file", "unknown")
        tags = data.get("tags", {})

        # Chunk 1: Factual summary
        if factual := data.get("factual_summary"):
            chunks.append(Chunk(
                chunk_id=f"{artifact_file}::factual_summary",
                content=f"[{source_file}] Factual summary: {factual.strip()}",
                chunk_type="factual_summary",
                artifact_file=artifact_file,
                source_file=source_file,
                tags=tags,
            ))

        # Chunk 2: Interpretive summary
        if interpretive := data.get("interpretive_summary"):
            chunks.append(Chunk(
                chunk_id=f"{artifact_file}::interpretive_summary",
                content=f"[{source_file}] Interpretive summary: {interpretive.strip()}",
                chunk_type="interpretive_summary",
                artifact_file=artifact_file,
                source_file=source_file,
                tags=tags,
            ))

        # Chunk per method
        if methods := data.get("methods"):
            for method_sig, method_data in methods.items():
                description = method_data.get("description", "") if isinstance(method_data, dict) else method_data
                if description:
                    chunks.append(Chunk(
                        chunk_id=f"{artifact_file}::method::{method_sig}",
                        content=f"[{source_file}] Method {method_sig}: {description.strip()}",
                        chunk_type="method",
                        artifact_file=artifact_file,
                        source_file=source_file,
                        tags=tags,
                    ))

        # Chunk for invariants (combined)
        if invariants := data.get("invariants"):
            invariants_text = " ".join(f"- {inv}" for inv in invariants)
            chunks.append(Chunk(
                chunk_id=f"{artifact_file}::invariants",
                content=f"[{source_file}] Invariants: {invariants_text}",
                chunk_type="invariants",
                artifact_file=artifact_file,
                source_file=source_file,
                tags=tags,
            ))

        return chunks

    def _collect_all_chunks(self) -> list[Chunk]:
        """Collect chunks from all YAML artifacts.

        Returns:
            List of all chunks from all artifacts
        """
        all_chunks = []

        for yaml_path in self.artifacts_path.rglob("*.yaml"):
            chunks = self._parse_yaml_to_chunks(yaml_path)
            all_chunks.extend(chunks)
            logger.debug(f"Parsed {len(chunks)} chunks from {yaml_path}")

        return all_chunks

    async def build_index(self) -> IndexResult:
        """Build FAISS index from all YAML artifacts.

        Returns:
            IndexResult with statistics

        Raises:
            EmbeddingError: If embedding generation fails
        """
        # Collect all chunks
        chunks = self._collect_all_chunks()

        if not chunks:
            logger.warning("No chunks found in artifacts")
            return IndexResult(
                chunks_indexed=0,
                artifacts_processed=0,
                status="no_artifacts",
            )

        logger.info(f"Generating embeddings for {len(chunks)} chunks...")

        # Generate embeddings in batches
        batch_size = 100
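        # Assumption: 100 texts per request is a conservative cap; adjust it to the
        # embedding provider's per-request limits if they differ.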
        all_embeddings = []

        for i in range(0, len(chunks), batch_size):
            batch = chunks[i:i + batch_size]
            texts = [chunk.content for chunk in batch]
            embeddings = await self.embedding_client.embed_batch(texts)
            all_embeddings.extend(embeddings)
            logger.debug(f"Embedded batch {i // batch_size + 1}")

        # Build FAISS index (IndexFlatIP for inner product / cosine similarity on normalized vectors)
        embeddings_array = np.array(all_embeddings, dtype=np.float32)

        # Normalize for cosine similarity
        faiss.normalize_L2(embeddings_array)

        index = faiss.IndexFlatIP(self.dimensions)
        index.add(embeddings_array)
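        # Query vectors must be L2-normalized the same way at search time so that
        # the inner-product scores returned by IndexFlatIP are cosine similarities.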

        # Create embeddings directory if needed
        self.embeddings_path.mkdir(parents=True, exist_ok=True)

        # Save FAISS index
        faiss.write_index(index, str(self.embeddings_path / "faiss_index.bin"))

        # Save metadata
        metadata = {
            "chunks": [asdict(chunk) for chunk in chunks],
        }
        with open(self.embeddings_path / "metadata.json", "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)
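        # Row i of the FAISS index corresponds to metadata["chunks"][i], so search
        # results can be mapped back to their chunks by position.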

        # Save index info
        artifact_files = {chunk.artifact_file for chunk in chunks}
        index_info = {
            "total_chunks": len(chunks),
            "total_artifacts": len(artifact_files),
            "dimensions": self.dimensions,
            "index_type": "IndexFlatIP",
        }
        with open(self.embeddings_path / "index_info.json", "w", encoding="utf-8") as f:
            json.dump(index_info, f, indent=2)

        logger.info(f"Indexed {len(chunks)} chunks from {len(artifact_files)} artifacts")

        return IndexResult(
            chunks_indexed=len(chunks),
            artifacts_processed=len(artifact_files),
            status="completed",
        )


@lru_cache()
def get_indexer() -> ArtifactIndexer:
    """Get cached indexer instance."""
    return ArtifactIndexer(embedding_client=get_embedding_client())


IndexerDependency = Annotated[ArtifactIndexer, Depends(get_indexer)]
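

# Illustrative usage from a FastAPI route (a sketch; the router object, route path,
# and import path for this module are assumptions, not part of this file):
#
#   from fastapi import APIRouter
#   from app.indexer import IndexerDependency, IndexResult
#
#   router = APIRouter()
#
#   @router.post("/index/rebuild")
#   async def rebuild_index(indexer: IndexerDependency) -> IndexResult:
#       return await indexer.build_index()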