feat: replace Redis with in-memory conversation storage

- Remove Redis dependency and redis_client.py
- Implement ConversationMemory with module-level dictionary
- Add TTL support via timestamp checking
- Remove redis_connected from health endpoint
- Add embeddings, intent classification, and RAG prompt modules

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Danny
2026-01-30 10:34:47 -06:00
parent 72778b65b5
commit b0211b944d
13 changed files with 1168 additions and 16 deletions
+18
View File
@@ -0,0 +1,18 @@
"""Embeddings module for YAML artifact indexing and retrieval."""
from app.embeddings.client import EmbeddingClient, get_embedding_client, EmbeddingClientDependency
from app.embeddings.indexer import ArtifactIndexer, get_indexer, IndexerDependency
from app.embeddings.retriever import Retriever, get_retriever, RetrieverDependency, reset_retriever
__all__ = [
"EmbeddingClient",
"get_embedding_client",
"EmbeddingClientDependency",
"ArtifactIndexer",
"get_indexer",
"IndexerDependency",
"Retriever",
"get_retriever",
"RetrieverDependency",
"reset_retriever",
]
+109
View File
@@ -0,0 +1,109 @@
"""OpenAI embedding client wrapper."""
import logging
from functools import lru_cache
from typing import Annotated
from fastapi import Depends
from openai import AsyncOpenAI, AuthenticationError, RateLimitError, APIConnectionError, APIError
from app.config import settings
logger = logging.getLogger(__name__)
class EmbeddingError(Exception):
"""Base exception for embedding operations."""
def __init__(self, message: str, status_code: int = 500):
self.message = message
self.status_code = status_code
super().__init__(message)
class EmbeddingClient:
"""Async wrapper for OpenAI embeddings API."""
def __init__(self, api_key: str, model: str = "text-embedding-3-small"):
"""Initialize the embedding client.
Args:
api_key: OpenAI API key
model: Embedding model identifier
"""
self.client = AsyncOpenAI(api_key=api_key)
self.model = model
self.dimensions = 1536 # text-embedding-3-small dimension
async def embed(self, text: str) -> list[float]:
"""Generate embedding for a single text.
Args:
text: Text to embed
Returns:
Embedding vector (1536 dimensions)
Raises:
EmbeddingError: If embedding generation fails
"""
try:
response = await self.client.embeddings.create(
model=self.model,
input=text,
)
return response.data[0].embedding
except AuthenticationError as e:
raise EmbeddingError(f"OpenAI authentication failed: {e.message}", 401)
except RateLimitError as e:
raise EmbeddingError(f"OpenAI rate limit exceeded: {e.message}", 429)
except APIConnectionError as e:
raise EmbeddingError(f"Could not connect to OpenAI: {str(e)}", 503)
except APIError as e:
raise EmbeddingError(f"OpenAI API error: {e.message}", e.status_code or 500)
async def embed_batch(self, texts: list[str]) -> list[list[float]]:
"""Generate embeddings for multiple texts.
Args:
texts: List of texts to embed
Returns:
List of embedding vectors
Raises:
EmbeddingError: If embedding generation fails
"""
if not texts:
return []
try:
response = await self.client.embeddings.create(
model=self.model,
input=texts,
)
# Sort by index to ensure correct ordering
sorted_embeddings = sorted(response.data, key=lambda x: x.index)
return [item.embedding for item in sorted_embeddings]
except AuthenticationError as e:
raise EmbeddingError(f"OpenAI authentication failed: {e.message}", 401)
except RateLimitError as e:
raise EmbeddingError(f"OpenAI rate limit exceeded: {e.message}", 429)
except APIConnectionError as e:
raise EmbeddingError(f"Could not connect to OpenAI: {str(e)}", 503)
except APIError as e:
raise EmbeddingError(f"OpenAI API error: {e.message}", e.status_code or 500)
@lru_cache()
def get_embedding_client() -> EmbeddingClient:
"""Get cached embedding client instance."""
return EmbeddingClient(
api_key=settings.openai_api_key,
model=settings.embedding_model,
)
EmbeddingClientDependency = Annotated[EmbeddingClient, Depends(get_embedding_client)]
+227
View File
@@ -0,0 +1,227 @@
"""YAML artifact indexer for building FAISS index."""
import json
import logging
from dataclasses import dataclass, asdict
from functools import lru_cache
from pathlib import Path
from typing import Annotated
import faiss
import numpy as np
import yaml
from fastapi import Depends
from app.config import settings
from app.embeddings.client import EmbeddingClient, get_embedding_client, EmbeddingError
logger = logging.getLogger(__name__)
@dataclass
class Chunk:
"""Represents a text chunk from a YAML artifact."""
chunk_id: str
content: str
chunk_type: str # factual_summary, interpretive_summary, method, invariants
artifact_file: str
source_file: str
tags: dict
@dataclass
class IndexResult:
"""Result of indexing operation."""
chunks_indexed: int
artifacts_processed: int
status: str
class ArtifactIndexer:
"""Parses YAML artifacts and builds FAISS index."""
def __init__(self, embedding_client: EmbeddingClient):
"""Initialize the indexer.
Args:
embedding_client: Client for generating embeddings
"""
self.embedding_client = embedding_client
self.artifacts_path = Path(settings.artifacts_path)
self.embeddings_path = Path(settings.embeddings_path)
self.dimensions = 1536
def _parse_yaml_to_chunks(self, yaml_path: Path) -> list[Chunk]:
"""Parse a YAML artifact file into chunks.
Args:
yaml_path: Path to the YAML file
Returns:
List of chunks extracted from the file
"""
chunks = []
artifact_file = str(yaml_path.relative_to(self.artifacts_path))
try:
with open(yaml_path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
except Exception as e:
logger.warning(f"Failed to parse {yaml_path}: {e}")
return chunks
if not data:
return chunks
source_file = data.get("source_file", "unknown")
tags = data.get("tags", {})
# Chunk 1: Factual summary
if factual := data.get("factual_summary"):
chunks.append(Chunk(
chunk_id=f"{artifact_file}::factual_summary",
content=f"[{source_file}] Factual summary: {factual.strip()}",
chunk_type="factual_summary",
artifact_file=artifact_file,
source_file=source_file,
tags=tags,
))
# Chunk 2: Interpretive summary
if interpretive := data.get("interpretive_summary"):
chunks.append(Chunk(
chunk_id=f"{artifact_file}::interpretive_summary",
content=f"[{source_file}] Interpretive summary: {interpretive.strip()}",
chunk_type="interpretive_summary",
artifact_file=artifact_file,
source_file=source_file,
tags=tags,
))
# Chunk per method
if methods := data.get("methods"):
for method_sig, method_data in methods.items():
description = method_data.get("description", "") if isinstance(method_data, dict) else method_data
if description:
chunks.append(Chunk(
chunk_id=f"{artifact_file}::method::{method_sig}",
content=f"[{source_file}] Method {method_sig}: {description.strip()}",
chunk_type="method",
artifact_file=artifact_file,
source_file=source_file,
tags=tags,
))
# Chunk for invariants (combined)
if invariants := data.get("invariants"):
invariants_text = " ".join(f"- {inv}" for inv in invariants)
chunks.append(Chunk(
chunk_id=f"{artifact_file}::invariants",
content=f"[{source_file}] Invariants: {invariants_text}",
chunk_type="invariants",
artifact_file=artifact_file,
source_file=source_file,
tags=tags,
))
return chunks
def _collect_all_chunks(self) -> list[Chunk]:
"""Collect chunks from all YAML artifacts.
Returns:
List of all chunks from all artifacts
"""
all_chunks = []
for yaml_path in self.artifacts_path.rglob("*.yaml"):
chunks = self._parse_yaml_to_chunks(yaml_path)
all_chunks.extend(chunks)
logger.debug(f"Parsed {len(chunks)} chunks from {yaml_path}")
return all_chunks
async def build_index(self) -> IndexResult:
"""Build FAISS index from all YAML artifacts.
Returns:
IndexResult with statistics
Raises:
EmbeddingError: If embedding generation fails
"""
# Collect all chunks
chunks = self._collect_all_chunks()
if not chunks:
logger.warning("No chunks found in artifacts")
return IndexResult(
chunks_indexed=0,
artifacts_processed=0,
status="no_artifacts",
)
logger.info(f"Generating embeddings for {len(chunks)} chunks...")
# Generate embeddings in batches
batch_size = 100
all_embeddings = []
for i in range(0, len(chunks), batch_size):
batch = chunks[i:i + batch_size]
texts = [chunk.content for chunk in batch]
embeddings = await self.embedding_client.embed_batch(texts)
all_embeddings.extend(embeddings)
logger.debug(f"Embedded batch {i // batch_size + 1}")
# Build FAISS index (IndexFlatIP for inner product / cosine similarity on normalized vectors)
embeddings_array = np.array(all_embeddings, dtype=np.float32)
# Normalize for cosine similarity
faiss.normalize_L2(embeddings_array)
index = faiss.IndexFlatIP(self.dimensions)
index.add(embeddings_array)
# Create embeddings directory if needed
self.embeddings_path.mkdir(parents=True, exist_ok=True)
# Save FAISS index
faiss.write_index(index, str(self.embeddings_path / "faiss_index.bin"))
# Save metadata
metadata = {
"chunks": [asdict(chunk) for chunk in chunks],
}
with open(self.embeddings_path / "metadata.json", "w", encoding="utf-8") as f:
json.dump(metadata, f, indent=2)
# Save index info
artifact_files = set(chunk.artifact_file for chunk in chunks)
index_info = {
"total_chunks": len(chunks),
"total_artifacts": len(artifact_files),
"dimensions": self.dimensions,
"index_type": "IndexFlatIP",
}
with open(self.embeddings_path / "index_info.json", "w", encoding="utf-8") as f:
json.dump(index_info, f, indent=2)
logger.info(f"Indexed {len(chunks)} chunks from {len(artifact_files)} artifacts")
return IndexResult(
chunks_indexed=len(chunks),
artifacts_processed=len(artifact_files),
status="completed",
)
@lru_cache()
def get_indexer() -> ArtifactIndexer:
"""Get cached indexer instance."""
return ArtifactIndexer(embedding_client=get_embedding_client())
IndexerDependency = Annotated[ArtifactIndexer, Depends(get_indexer)]
+221
View File
@@ -0,0 +1,221 @@
"""FAISS-based retrieval with adaptive selection."""
import json
import logging
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Annotated
import faiss
import numpy as np
from fastapi import Depends
from app.config import settings
from app.embeddings.client import EmbeddingClient, get_embedding_client
logger = logging.getLogger(__name__)
@dataclass
class RetrievedChunk:
"""A chunk retrieved from FAISS search."""
chunk_id: str
content: str
chunk_type: str
artifact_file: str
source_file: str
tags: dict
score: float
class Retriever:
"""FAISS-based retriever with adaptive selection logic."""
def __init__(self, embedding_client: EmbeddingClient):
"""Initialize the retriever.
Args:
embedding_client: Client for generating query embeddings
"""
self.embedding_client = embedding_client
self.embeddings_path = Path(settings.embeddings_path)
self.top_k = settings.rag_top_k
self.threshold = settings.rag_similarity_threshold
self._index: faiss.IndexFlatIP | None = None
self._metadata: list[dict] | None = None
self._loaded = False
def load_index(self) -> bool:
"""Load FAISS index and metadata from disk.
Returns:
True if successfully loaded, False otherwise
"""
index_path = self.embeddings_path / "faiss_index.bin"
metadata_path = self.embeddings_path / "metadata.json"
if not index_path.exists() or not metadata_path.exists():
logger.warning("FAISS index or metadata not found. Run /index first.")
self._loaded = False
return False
try:
self._index = faiss.read_index(str(index_path))
with open(metadata_path, "r", encoding="utf-8") as f:
data = json.load(f)
self._metadata = data.get("chunks", [])
self._loaded = True
logger.info(f"Loaded FAISS index with {self._index.ntotal} vectors")
return True
except Exception as e:
logger.error(f"Failed to load FAISS index: {e}")
self._loaded = False
return False
@property
def is_loaded(self) -> bool:
"""Check if index is loaded."""
return self._loaded and self._index is not None
@property
def index_size(self) -> int:
"""Get number of vectors in index."""
if self._index is None:
return 0
return self._index.ntotal
def _adaptive_select(
self,
indices: np.ndarray,
scores: np.ndarray,
) -> list[tuple[int, float]]:
"""Apply adaptive selection logic.
- Always include top 2 chunks (regardless of score)
- For chunks 3-5: apply threshold
- Limit to self.top_k chunks total
Args:
indices: FAISS result indices
scores: FAISS result scores
Returns:
List of (index, score) tuples for selected chunks
"""
selected = []
for i, (idx, score) in enumerate(zip(indices, scores)):
if idx == -1: # FAISS returns -1 for no match
continue
# Always take top 2
if i < 2:
selected.append((int(idx), float(score)))
# Apply threshold for remaining
elif score >= self.threshold and len(selected) < self.top_k:
selected.append((int(idx), float(score)))
return selected
def _apply_diversity_filter(
self,
candidates: list[tuple[int, float]],
max_per_artifact: int = 2,
) -> list[tuple[int, float]]:
"""Limit chunks per artifact for diversity.
Args:
candidates: List of (index, score) tuples
max_per_artifact: Maximum chunks from same artifact
Returns:
Filtered list of (index, score) tuples
"""
artifact_counts: dict[str, int] = {}
filtered = []
for idx, score in candidates:
chunk = self._metadata[idx]
artifact = chunk["artifact_file"]
if artifact_counts.get(artifact, 0) < max_per_artifact:
filtered.append((idx, score))
artifact_counts[artifact] = artifact_counts.get(artifact, 0) + 1
return filtered
async def search(self, query: str) -> list[RetrievedChunk]:
"""Search for relevant chunks.
Args:
query: User's question
Returns:
List of retrieved chunks with relevance scores
"""
if not self.is_loaded:
if not self.load_index():
return []
# Generate query embedding
query_embedding = await self.embedding_client.embed(query)
query_vector = np.array([query_embedding], dtype=np.float32)
# Normalize for cosine similarity
faiss.normalize_L2(query_vector)
# Search FAISS (get more candidates than needed for filtering)
k_search = min(8, self._index.ntotal)
scores, indices = self._index.search(query_vector, k_search)
# Apply adaptive selection
selected = self._adaptive_select(indices[0], scores[0])
# Apply diversity filter
filtered = self._apply_diversity_filter(selected)
# Build result chunks
results = []
for idx, score in filtered:
chunk_data = self._metadata[idx]
results.append(RetrievedChunk(
chunk_id=chunk_data["chunk_id"],
content=chunk_data["content"],
chunk_type=chunk_data["chunk_type"],
artifact_file=chunk_data["artifact_file"],
source_file=chunk_data["source_file"],
tags=chunk_data.get("tags", {}),
score=score,
))
logger.debug(f"Retrieved {len(results)} chunks for query")
return results
# Singleton retriever instance
_retriever: Retriever | None = None
def get_retriever() -> Retriever:
"""Get singleton retriever instance (lazily initialized)."""
global _retriever
if _retriever is None:
_retriever = Retriever(embedding_client=get_embedding_client())
# Attempt to load index at startup
_retriever.load_index()
return _retriever
def reset_retriever() -> None:
"""Reset the singleton retriever (for reloading after re-indexing)."""
global _retriever
_retriever = None
RetrieverDependency = Annotated[Retriever, Depends(get_retriever)]