feat: replace Redis with in-memory conversation storage
- Remove Redis dependency and redis_client.py - Implement ConversationMemory with module-level dictionary - Add TTL support via timestamp checking - Remove redis_connected from health endpoint - Add embeddings, intent classification, and RAG prompt modules Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,18 @@
|
||||
"""Embeddings module for YAML artifact indexing and retrieval."""
|
||||
|
||||
from app.embeddings.client import EmbeddingClient, get_embedding_client, EmbeddingClientDependency
|
||||
from app.embeddings.indexer import ArtifactIndexer, get_indexer, IndexerDependency
|
||||
from app.embeddings.retriever import Retriever, get_retriever, RetrieverDependency, reset_retriever
|
||||
|
||||
__all__ = [
|
||||
"EmbeddingClient",
|
||||
"get_embedding_client",
|
||||
"EmbeddingClientDependency",
|
||||
"ArtifactIndexer",
|
||||
"get_indexer",
|
||||
"IndexerDependency",
|
||||
"Retriever",
|
||||
"get_retriever",
|
||||
"RetrieverDependency",
|
||||
"reset_retriever",
|
||||
]
|
||||
@@ -0,0 +1,109 @@
|
||||
"""OpenAI embedding client wrapper."""
|
||||
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from openai import AsyncOpenAI, AuthenticationError, RateLimitError, APIConnectionError, APIError
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingError(Exception):
|
||||
"""Base exception for embedding operations."""
|
||||
|
||||
def __init__(self, message: str, status_code: int = 500):
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class EmbeddingClient:
|
||||
"""Async wrapper for OpenAI embeddings API."""
|
||||
|
||||
def __init__(self, api_key: str, model: str = "text-embedding-3-small"):
|
||||
"""Initialize the embedding client.
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API key
|
||||
model: Embedding model identifier
|
||||
"""
|
||||
self.client = AsyncOpenAI(api_key=api_key)
|
||||
self.model = model
|
||||
self.dimensions = 1536 # text-embedding-3-small dimension
|
||||
|
||||
async def embed(self, text: str) -> list[float]:
|
||||
"""Generate embedding for a single text.
|
||||
|
||||
Args:
|
||||
text: Text to embed
|
||||
|
||||
Returns:
|
||||
Embedding vector (1536 dimensions)
|
||||
|
||||
Raises:
|
||||
EmbeddingError: If embedding generation fails
|
||||
"""
|
||||
try:
|
||||
response = await self.client.embeddings.create(
|
||||
model=self.model,
|
||||
input=text,
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
except AuthenticationError as e:
|
||||
raise EmbeddingError(f"OpenAI authentication failed: {e.message}", 401)
|
||||
except RateLimitError as e:
|
||||
raise EmbeddingError(f"OpenAI rate limit exceeded: {e.message}", 429)
|
||||
except APIConnectionError as e:
|
||||
raise EmbeddingError(f"Could not connect to OpenAI: {str(e)}", 503)
|
||||
except APIError as e:
|
||||
raise EmbeddingError(f"OpenAI API error: {e.message}", e.status_code or 500)
|
||||
|
||||
async def embed_batch(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Generate embeddings for multiple texts.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed
|
||||
|
||||
Returns:
|
||||
List of embedding vectors
|
||||
|
||||
Raises:
|
||||
EmbeddingError: If embedding generation fails
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
try:
|
||||
response = await self.client.embeddings.create(
|
||||
model=self.model,
|
||||
input=texts,
|
||||
)
|
||||
# Sort by index to ensure correct ordering
|
||||
sorted_embeddings = sorted(response.data, key=lambda x: x.index)
|
||||
return [item.embedding for item in sorted_embeddings]
|
||||
|
||||
except AuthenticationError as e:
|
||||
raise EmbeddingError(f"OpenAI authentication failed: {e.message}", 401)
|
||||
except RateLimitError as e:
|
||||
raise EmbeddingError(f"OpenAI rate limit exceeded: {e.message}", 429)
|
||||
except APIConnectionError as e:
|
||||
raise EmbeddingError(f"Could not connect to OpenAI: {str(e)}", 503)
|
||||
except APIError as e:
|
||||
raise EmbeddingError(f"OpenAI API error: {e.message}", e.status_code or 500)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_embedding_client() -> EmbeddingClient:
|
||||
"""Get cached embedding client instance."""
|
||||
return EmbeddingClient(
|
||||
api_key=settings.openai_api_key,
|
||||
model=settings.embedding_model,
|
||||
)
|
||||
|
||||
|
||||
EmbeddingClientDependency = Annotated[EmbeddingClient, Depends(get_embedding_client)]
|
||||
@@ -0,0 +1,227 @@
|
||||
"""YAML artifact indexer for building FAISS index."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, asdict
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
import yaml
|
||||
from fastapi import Depends
|
||||
|
||||
from app.config import settings
|
||||
from app.embeddings.client import EmbeddingClient, get_embedding_client, EmbeddingError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Chunk:
|
||||
"""Represents a text chunk from a YAML artifact."""
|
||||
|
||||
chunk_id: str
|
||||
content: str
|
||||
chunk_type: str # factual_summary, interpretive_summary, method, invariants
|
||||
artifact_file: str
|
||||
source_file: str
|
||||
tags: dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexResult:
|
||||
"""Result of indexing operation."""
|
||||
|
||||
chunks_indexed: int
|
||||
artifacts_processed: int
|
||||
status: str
|
||||
|
||||
|
||||
class ArtifactIndexer:
|
||||
"""Parses YAML artifacts and builds FAISS index."""
|
||||
|
||||
def __init__(self, embedding_client: EmbeddingClient):
|
||||
"""Initialize the indexer.
|
||||
|
||||
Args:
|
||||
embedding_client: Client for generating embeddings
|
||||
"""
|
||||
self.embedding_client = embedding_client
|
||||
self.artifacts_path = Path(settings.artifacts_path)
|
||||
self.embeddings_path = Path(settings.embeddings_path)
|
||||
self.dimensions = 1536
|
||||
|
||||
def _parse_yaml_to_chunks(self, yaml_path: Path) -> list[Chunk]:
|
||||
"""Parse a YAML artifact file into chunks.
|
||||
|
||||
Args:
|
||||
yaml_path: Path to the YAML file
|
||||
|
||||
Returns:
|
||||
List of chunks extracted from the file
|
||||
"""
|
||||
chunks = []
|
||||
artifact_file = str(yaml_path.relative_to(self.artifacts_path))
|
||||
|
||||
try:
|
||||
with open(yaml_path, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse {yaml_path}: {e}")
|
||||
return chunks
|
||||
|
||||
if not data:
|
||||
return chunks
|
||||
|
||||
source_file = data.get("source_file", "unknown")
|
||||
tags = data.get("tags", {})
|
||||
|
||||
# Chunk 1: Factual summary
|
||||
if factual := data.get("factual_summary"):
|
||||
chunks.append(Chunk(
|
||||
chunk_id=f"{artifact_file}::factual_summary",
|
||||
content=f"[{source_file}] Factual summary: {factual.strip()}",
|
||||
chunk_type="factual_summary",
|
||||
artifact_file=artifact_file,
|
||||
source_file=source_file,
|
||||
tags=tags,
|
||||
))
|
||||
|
||||
# Chunk 2: Interpretive summary
|
||||
if interpretive := data.get("interpretive_summary"):
|
||||
chunks.append(Chunk(
|
||||
chunk_id=f"{artifact_file}::interpretive_summary",
|
||||
content=f"[{source_file}] Interpretive summary: {interpretive.strip()}",
|
||||
chunk_type="interpretive_summary",
|
||||
artifact_file=artifact_file,
|
||||
source_file=source_file,
|
||||
tags=tags,
|
||||
))
|
||||
|
||||
# Chunk per method
|
||||
if methods := data.get("methods"):
|
||||
for method_sig, method_data in methods.items():
|
||||
description = method_data.get("description", "") if isinstance(method_data, dict) else method_data
|
||||
if description:
|
||||
chunks.append(Chunk(
|
||||
chunk_id=f"{artifact_file}::method::{method_sig}",
|
||||
content=f"[{source_file}] Method {method_sig}: {description.strip()}",
|
||||
chunk_type="method",
|
||||
artifact_file=artifact_file,
|
||||
source_file=source_file,
|
||||
tags=tags,
|
||||
))
|
||||
|
||||
# Chunk for invariants (combined)
|
||||
if invariants := data.get("invariants"):
|
||||
invariants_text = " ".join(f"- {inv}" for inv in invariants)
|
||||
chunks.append(Chunk(
|
||||
chunk_id=f"{artifact_file}::invariants",
|
||||
content=f"[{source_file}] Invariants: {invariants_text}",
|
||||
chunk_type="invariants",
|
||||
artifact_file=artifact_file,
|
||||
source_file=source_file,
|
||||
tags=tags,
|
||||
))
|
||||
|
||||
return chunks
|
||||
|
||||
def _collect_all_chunks(self) -> list[Chunk]:
|
||||
"""Collect chunks from all YAML artifacts.
|
||||
|
||||
Returns:
|
||||
List of all chunks from all artifacts
|
||||
"""
|
||||
all_chunks = []
|
||||
|
||||
for yaml_path in self.artifacts_path.rglob("*.yaml"):
|
||||
chunks = self._parse_yaml_to_chunks(yaml_path)
|
||||
all_chunks.extend(chunks)
|
||||
logger.debug(f"Parsed {len(chunks)} chunks from {yaml_path}")
|
||||
|
||||
return all_chunks
|
||||
|
||||
async def build_index(self) -> IndexResult:
|
||||
"""Build FAISS index from all YAML artifacts.
|
||||
|
||||
Returns:
|
||||
IndexResult with statistics
|
||||
|
||||
Raises:
|
||||
EmbeddingError: If embedding generation fails
|
||||
"""
|
||||
# Collect all chunks
|
||||
chunks = self._collect_all_chunks()
|
||||
|
||||
if not chunks:
|
||||
logger.warning("No chunks found in artifacts")
|
||||
return IndexResult(
|
||||
chunks_indexed=0,
|
||||
artifacts_processed=0,
|
||||
status="no_artifacts",
|
||||
)
|
||||
|
||||
logger.info(f"Generating embeddings for {len(chunks)} chunks...")
|
||||
|
||||
# Generate embeddings in batches
|
||||
batch_size = 100
|
||||
all_embeddings = []
|
||||
|
||||
for i in range(0, len(chunks), batch_size):
|
||||
batch = chunks[i:i + batch_size]
|
||||
texts = [chunk.content for chunk in batch]
|
||||
embeddings = await self.embedding_client.embed_batch(texts)
|
||||
all_embeddings.extend(embeddings)
|
||||
logger.debug(f"Embedded batch {i // batch_size + 1}")
|
||||
|
||||
# Build FAISS index (IndexFlatIP for inner product / cosine similarity on normalized vectors)
|
||||
embeddings_array = np.array(all_embeddings, dtype=np.float32)
|
||||
|
||||
# Normalize for cosine similarity
|
||||
faiss.normalize_L2(embeddings_array)
|
||||
|
||||
index = faiss.IndexFlatIP(self.dimensions)
|
||||
index.add(embeddings_array)
|
||||
|
||||
# Create embeddings directory if needed
|
||||
self.embeddings_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save FAISS index
|
||||
faiss.write_index(index, str(self.embeddings_path / "faiss_index.bin"))
|
||||
|
||||
# Save metadata
|
||||
metadata = {
|
||||
"chunks": [asdict(chunk) for chunk in chunks],
|
||||
}
|
||||
with open(self.embeddings_path / "metadata.json", "w", encoding="utf-8") as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
# Save index info
|
||||
artifact_files = set(chunk.artifact_file for chunk in chunks)
|
||||
index_info = {
|
||||
"total_chunks": len(chunks),
|
||||
"total_artifacts": len(artifact_files),
|
||||
"dimensions": self.dimensions,
|
||||
"index_type": "IndexFlatIP",
|
||||
}
|
||||
with open(self.embeddings_path / "index_info.json", "w", encoding="utf-8") as f:
|
||||
json.dump(index_info, f, indent=2)
|
||||
|
||||
logger.info(f"Indexed {len(chunks)} chunks from {len(artifact_files)} artifacts")
|
||||
|
||||
return IndexResult(
|
||||
chunks_indexed=len(chunks),
|
||||
artifacts_processed=len(artifact_files),
|
||||
status="completed",
|
||||
)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_indexer() -> ArtifactIndexer:
|
||||
"""Get cached indexer instance."""
|
||||
return ArtifactIndexer(embedding_client=get_embedding_client())
|
||||
|
||||
|
||||
IndexerDependency = Annotated[ArtifactIndexer, Depends(get_indexer)]
|
||||
@@ -0,0 +1,221 @@
|
||||
"""FAISS-based retrieval with adaptive selection."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
from fastapi import Depends
|
||||
|
||||
from app.config import settings
|
||||
from app.embeddings.client import EmbeddingClient, get_embedding_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievedChunk:
|
||||
"""A chunk retrieved from FAISS search."""
|
||||
|
||||
chunk_id: str
|
||||
content: str
|
||||
chunk_type: str
|
||||
artifact_file: str
|
||||
source_file: str
|
||||
tags: dict
|
||||
score: float
|
||||
|
||||
|
||||
class Retriever:
|
||||
"""FAISS-based retriever with adaptive selection logic."""
|
||||
|
||||
def __init__(self, embedding_client: EmbeddingClient):
|
||||
"""Initialize the retriever.
|
||||
|
||||
Args:
|
||||
embedding_client: Client for generating query embeddings
|
||||
"""
|
||||
self.embedding_client = embedding_client
|
||||
self.embeddings_path = Path(settings.embeddings_path)
|
||||
self.top_k = settings.rag_top_k
|
||||
self.threshold = settings.rag_similarity_threshold
|
||||
|
||||
self._index: faiss.IndexFlatIP | None = None
|
||||
self._metadata: list[dict] | None = None
|
||||
self._loaded = False
|
||||
|
||||
def load_index(self) -> bool:
|
||||
"""Load FAISS index and metadata from disk.
|
||||
|
||||
Returns:
|
||||
True if successfully loaded, False otherwise
|
||||
"""
|
||||
index_path = self.embeddings_path / "faiss_index.bin"
|
||||
metadata_path = self.embeddings_path / "metadata.json"
|
||||
|
||||
if not index_path.exists() or not metadata_path.exists():
|
||||
logger.warning("FAISS index or metadata not found. Run /index first.")
|
||||
self._loaded = False
|
||||
return False
|
||||
|
||||
try:
|
||||
self._index = faiss.read_index(str(index_path))
|
||||
|
||||
with open(metadata_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
self._metadata = data.get("chunks", [])
|
||||
|
||||
self._loaded = True
|
||||
logger.info(f"Loaded FAISS index with {self._index.ntotal} vectors")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load FAISS index: {e}")
|
||||
self._loaded = False
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
"""Check if index is loaded."""
|
||||
return self._loaded and self._index is not None
|
||||
|
||||
@property
|
||||
def index_size(self) -> int:
|
||||
"""Get number of vectors in index."""
|
||||
if self._index is None:
|
||||
return 0
|
||||
return self._index.ntotal
|
||||
|
||||
def _adaptive_select(
|
||||
self,
|
||||
indices: np.ndarray,
|
||||
scores: np.ndarray,
|
||||
) -> list[tuple[int, float]]:
|
||||
"""Apply adaptive selection logic.
|
||||
|
||||
- Always include top 2 chunks (regardless of score)
|
||||
- For chunks 3-5: apply threshold
|
||||
- Limit to self.top_k chunks total
|
||||
|
||||
Args:
|
||||
indices: FAISS result indices
|
||||
scores: FAISS result scores
|
||||
|
||||
Returns:
|
||||
List of (index, score) tuples for selected chunks
|
||||
"""
|
||||
selected = []
|
||||
|
||||
for i, (idx, score) in enumerate(zip(indices, scores)):
|
||||
if idx == -1: # FAISS returns -1 for no match
|
||||
continue
|
||||
|
||||
# Always take top 2
|
||||
if i < 2:
|
||||
selected.append((int(idx), float(score)))
|
||||
# Apply threshold for remaining
|
||||
elif score >= self.threshold and len(selected) < self.top_k:
|
||||
selected.append((int(idx), float(score)))
|
||||
|
||||
return selected
|
||||
|
||||
def _apply_diversity_filter(
|
||||
self,
|
||||
candidates: list[tuple[int, float]],
|
||||
max_per_artifact: int = 2,
|
||||
) -> list[tuple[int, float]]:
|
||||
"""Limit chunks per artifact for diversity.
|
||||
|
||||
Args:
|
||||
candidates: List of (index, score) tuples
|
||||
max_per_artifact: Maximum chunks from same artifact
|
||||
|
||||
Returns:
|
||||
Filtered list of (index, score) tuples
|
||||
"""
|
||||
artifact_counts: dict[str, int] = {}
|
||||
filtered = []
|
||||
|
||||
for idx, score in candidates:
|
||||
chunk = self._metadata[idx]
|
||||
artifact = chunk["artifact_file"]
|
||||
|
||||
if artifact_counts.get(artifact, 0) < max_per_artifact:
|
||||
filtered.append((idx, score))
|
||||
artifact_counts[artifact] = artifact_counts.get(artifact, 0) + 1
|
||||
|
||||
return filtered
|
||||
|
||||
async def search(self, query: str) -> list[RetrievedChunk]:
|
||||
"""Search for relevant chunks.
|
||||
|
||||
Args:
|
||||
query: User's question
|
||||
|
||||
Returns:
|
||||
List of retrieved chunks with relevance scores
|
||||
"""
|
||||
if not self.is_loaded:
|
||||
if not self.load_index():
|
||||
return []
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = await self.embedding_client.embed(query)
|
||||
query_vector = np.array([query_embedding], dtype=np.float32)
|
||||
|
||||
# Normalize for cosine similarity
|
||||
faiss.normalize_L2(query_vector)
|
||||
|
||||
# Search FAISS (get more candidates than needed for filtering)
|
||||
k_search = min(8, self._index.ntotal)
|
||||
scores, indices = self._index.search(query_vector, k_search)
|
||||
|
||||
# Apply adaptive selection
|
||||
selected = self._adaptive_select(indices[0], scores[0])
|
||||
|
||||
# Apply diversity filter
|
||||
filtered = self._apply_diversity_filter(selected)
|
||||
|
||||
# Build result chunks
|
||||
results = []
|
||||
for idx, score in filtered:
|
||||
chunk_data = self._metadata[idx]
|
||||
results.append(RetrievedChunk(
|
||||
chunk_id=chunk_data["chunk_id"],
|
||||
content=chunk_data["content"],
|
||||
chunk_type=chunk_data["chunk_type"],
|
||||
artifact_file=chunk_data["artifact_file"],
|
||||
source_file=chunk_data["source_file"],
|
||||
tags=chunk_data.get("tags", {}),
|
||||
score=score,
|
||||
))
|
||||
|
||||
logger.debug(f"Retrieved {len(results)} chunks for query")
|
||||
return results
|
||||
|
||||
|
||||
# Singleton retriever instance
|
||||
_retriever: Retriever | None = None
|
||||
|
||||
|
||||
def get_retriever() -> Retriever:
|
||||
"""Get singleton retriever instance (lazily initialized)."""
|
||||
global _retriever
|
||||
if _retriever is None:
|
||||
_retriever = Retriever(embedding_client=get_embedding_client())
|
||||
# Attempt to load index at startup
|
||||
_retriever.load_index()
|
||||
return _retriever
|
||||
|
||||
|
||||
def reset_retriever() -> None:
|
||||
"""Reset the singleton retriever (for reloading after re-indexing)."""
|
||||
global _retriever
|
||||
_retriever = None
|
||||
|
||||
|
||||
RetrieverDependency = Annotated[Retriever, Depends(get_retriever)]
|
||||
Reference in New Issue
Block a user