feat: replace Redis with in-memory conversation storage
- Remove Redis dependency and redis_client.py
- Implement ConversationMemory with module-level dictionary
- Add TTL support via timestamp checking
- Remove redis_connected from health endpoint
- Add embeddings, intent classification, and RAG prompt modules

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
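The storage change the bullets describe amounts to a module-level dict plus a timestamp check on read; a minimal sketch of the idea (the helper name here is illustrative), distilled from the ConversationMemory.get_history implementation in the diff below:

    import time

    _conversations: dict[str, dict] = {}  # conversation_id -> {"messages": [...], "last_updated": float}

    def history_if_fresh(conversation_id: str, ttl: int) -> list:
        """Return stored messages unless the conversation has outlived its TTL."""
        data = _conversations.get(conversation_id)
        if data is None:
            return []
        if time.time() - data["last_updated"] > ttl:  # expired: evict lazily on access
            del _conversations[conversation_id]
            return []
        return data["messages"]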
parent 72778b65b5
commit b0211b944d
@@ -22,6 +22,18 @@ class Settings(BaseSettings):
     # CORS configuration
     cors_origins: str = "http://localhost:3000"
 
+    # Conversation memory configuration
+    conversation_ttl: int = 86400  # 24 hours
+
+    # Embedding configuration
+    embedding_model: str = "text-embedding-3-small"
+
+    # RAG configuration
+    rag_top_k: int = 5
+    rag_similarity_threshold: float = 0.70
+    artifacts_path: str = "artifacts"
+    embeddings_path: str = "embeddings"
+
     # Authentication settings
     auth_enabled: bool = False  # Set to True in production
     auth_audience: str = ""  # Backend Cloud Run URL (e.g., https://backend-xxx.run.app)
@@ -0,0 +1,18 @@
+"""Embeddings module for YAML artifact indexing and retrieval."""
+
+from app.embeddings.client import EmbeddingClient, get_embedding_client, EmbeddingClientDependency
+from app.embeddings.indexer import ArtifactIndexer, get_indexer, IndexerDependency
+from app.embeddings.retriever import Retriever, get_retriever, RetrieverDependency, reset_retriever
+
+__all__ = [
+    "EmbeddingClient",
+    "get_embedding_client",
+    "EmbeddingClientDependency",
+    "ArtifactIndexer",
+    "get_indexer",
+    "IndexerDependency",
+    "Retriever",
+    "get_retriever",
+    "RetrieverDependency",
+    "reset_retriever",
+]
@@ -0,0 +1,109 @@
+"""OpenAI embedding client wrapper."""
+
+import logging
+from functools import lru_cache
+from typing import Annotated
+
+from fastapi import Depends
+from openai import AsyncOpenAI, AuthenticationError, RateLimitError, APIConnectionError, APIError
+
+from app.config import settings
+
+logger = logging.getLogger(__name__)
+
+
+class EmbeddingError(Exception):
+    """Base exception for embedding operations."""
+
+    def __init__(self, message: str, status_code: int = 500):
+        self.message = message
+        self.status_code = status_code
+        super().__init__(message)
+
+
+class EmbeddingClient:
+    """Async wrapper for OpenAI embeddings API."""
+
+    def __init__(self, api_key: str, model: str = "text-embedding-3-small"):
+        """Initialize the embedding client.
+
+        Args:
+            api_key: OpenAI API key
+            model: Embedding model identifier
+        """
+        self.client = AsyncOpenAI(api_key=api_key)
+        self.model = model
+        self.dimensions = 1536  # text-embedding-3-small dimension
+
+    async def embed(self, text: str) -> list[float]:
+        """Generate embedding for a single text.
+
+        Args:
+            text: Text to embed
+
+        Returns:
+            Embedding vector (1536 dimensions)
+
+        Raises:
+            EmbeddingError: If embedding generation fails
+        """
+        try:
+            response = await self.client.embeddings.create(
+                model=self.model,
+                input=text,
+            )
+            return response.data[0].embedding
+
+        except AuthenticationError as e:
+            raise EmbeddingError(f"OpenAI authentication failed: {e.message}", 401)
+        except RateLimitError as e:
+            raise EmbeddingError(f"OpenAI rate limit exceeded: {e.message}", 429)
+        except APIConnectionError as e:
+            raise EmbeddingError(f"Could not connect to OpenAI: {str(e)}", 503)
+        except APIError as e:
+            raise EmbeddingError(f"OpenAI API error: {e.message}", e.status_code or 500)
+
+    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
+        """Generate embeddings for multiple texts.
+
+        Args:
+            texts: List of texts to embed
+
+        Returns:
+            List of embedding vectors
+
+        Raises:
+            EmbeddingError: If embedding generation fails
+        """
+        if not texts:
+            return []
+
+        try:
+            response = await self.client.embeddings.create(
+                model=self.model,
+                input=texts,
+            )
+            # Sort by index to ensure correct ordering
+            sorted_embeddings = sorted(response.data, key=lambda x: x.index)
+            return [item.embedding for item in sorted_embeddings]
+
+        except AuthenticationError as e:
+            raise EmbeddingError(f"OpenAI authentication failed: {e.message}", 401)
+        except RateLimitError as e:
+            raise EmbeddingError(f"OpenAI rate limit exceeded: {e.message}", 429)
+        except APIConnectionError as e:
+            raise EmbeddingError(f"Could not connect to OpenAI: {str(e)}", 503)
+        except APIError as e:
+            raise EmbeddingError(f"OpenAI API error: {e.message}", e.status_code or 500)
+
+
+@lru_cache()
+def get_embedding_client() -> EmbeddingClient:
+    """Get cached embedding client instance."""
+    return EmbeddingClient(
+        api_key=settings.openai_api_key,
+        model=settings.embedding_model,
+    )
+
+
+EmbeddingClientDependency = Annotated[EmbeddingClient, Depends(get_embedding_client)]
@@ -0,0 +1,227 @@
+"""YAML artifact indexer for building FAISS index."""
+
+import json
+import logging
+from dataclasses import dataclass, asdict
+from functools import lru_cache
+from pathlib import Path
+from typing import Annotated
+
+import faiss
+import numpy as np
+import yaml
+from fastapi import Depends
+
+from app.config import settings
+from app.embeddings.client import EmbeddingClient, get_embedding_client, EmbeddingError
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Chunk:
+    """Represents a text chunk from a YAML artifact."""
+
+    chunk_id: str
+    content: str
+    chunk_type: str  # factual_summary, interpretive_summary, method, invariants
+    artifact_file: str
+    source_file: str
+    tags: dict
+
+
+@dataclass
+class IndexResult:
+    """Result of indexing operation."""
+
+    chunks_indexed: int
+    artifacts_processed: int
+    status: str
+
+
+class ArtifactIndexer:
+    """Parses YAML artifacts and builds FAISS index."""
+
+    def __init__(self, embedding_client: EmbeddingClient):
+        """Initialize the indexer.
+
+        Args:
+            embedding_client: Client for generating embeddings
+        """
+        self.embedding_client = embedding_client
+        self.artifacts_path = Path(settings.artifacts_path)
+        self.embeddings_path = Path(settings.embeddings_path)
+        self.dimensions = 1536
+
+    def _parse_yaml_to_chunks(self, yaml_path: Path) -> list[Chunk]:
+        """Parse a YAML artifact file into chunks.
+
+        Args:
+            yaml_path: Path to the YAML file
+
+        Returns:
+            List of chunks extracted from the file
+        """
+        chunks = []
+        artifact_file = str(yaml_path.relative_to(self.artifacts_path))
+
+        try:
+            with open(yaml_path, "r", encoding="utf-8") as f:
+                data = yaml.safe_load(f)
+        except Exception as e:
+            logger.warning(f"Failed to parse {yaml_path}: {e}")
+            return chunks
+
+        if not data:
+            return chunks
+
+        source_file = data.get("source_file", "unknown")
+        tags = data.get("tags", {})
+
+        # Chunk 1: Factual summary
+        if factual := data.get("factual_summary"):
+            chunks.append(Chunk(
+                chunk_id=f"{artifact_file}::factual_summary",
+                content=f"[{source_file}] Factual summary: {factual.strip()}",
+                chunk_type="factual_summary",
+                artifact_file=artifact_file,
+                source_file=source_file,
+                tags=tags,
+            ))
+
+        # Chunk 2: Interpretive summary
+        if interpretive := data.get("interpretive_summary"):
+            chunks.append(Chunk(
+                chunk_id=f"{artifact_file}::interpretive_summary",
+                content=f"[{source_file}] Interpretive summary: {interpretive.strip()}",
+                chunk_type="interpretive_summary",
+                artifact_file=artifact_file,
+                source_file=source_file,
+                tags=tags,
+            ))
+
+        # Chunk per method
+        if methods := data.get("methods"):
+            for method_sig, method_data in methods.items():
+                description = method_data.get("description", "") if isinstance(method_data, dict) else method_data
+                if description:
+                    chunks.append(Chunk(
+                        chunk_id=f"{artifact_file}::method::{method_sig}",
+                        content=f"[{source_file}] Method {method_sig}: {description.strip()}",
+                        chunk_type="method",
+                        artifact_file=artifact_file,
+                        source_file=source_file,
+                        tags=tags,
+                    ))
+
+        # Chunk for invariants (combined)
+        if invariants := data.get("invariants"):
+            invariants_text = " ".join(f"- {inv}" for inv in invariants)
+            chunks.append(Chunk(
+                chunk_id=f"{artifact_file}::invariants",
+                content=f"[{source_file}] Invariants: {invariants_text}",
+                chunk_type="invariants",
+                artifact_file=artifact_file,
+                source_file=source_file,
+                tags=tags,
+            ))
+
+        return chunks
+
+    def _collect_all_chunks(self) -> list[Chunk]:
+        """Collect chunks from all YAML artifacts.
+
+        Returns:
+            List of all chunks from all artifacts
+        """
+        all_chunks = []
+
+        for yaml_path in self.artifacts_path.rglob("*.yaml"):
+            chunks = self._parse_yaml_to_chunks(yaml_path)
+            all_chunks.extend(chunks)
+            logger.debug(f"Parsed {len(chunks)} chunks from {yaml_path}")
+
+        return all_chunks
+
+    async def build_index(self) -> IndexResult:
+        """Build FAISS index from all YAML artifacts.
+
+        Returns:
+            IndexResult with statistics
+
+        Raises:
+            EmbeddingError: If embedding generation fails
+        """
+        # Collect all chunks
+        chunks = self._collect_all_chunks()
+
+        if not chunks:
+            logger.warning("No chunks found in artifacts")
+            return IndexResult(
+                chunks_indexed=0,
+                artifacts_processed=0,
+                status="no_artifacts",
+            )
+
+        logger.info(f"Generating embeddings for {len(chunks)} chunks...")
+
+        # Generate embeddings in batches
+        batch_size = 100
+        all_embeddings = []
+
+        for i in range(0, len(chunks), batch_size):
+            batch = chunks[i:i + batch_size]
+            texts = [chunk.content for chunk in batch]
+            embeddings = await self.embedding_client.embed_batch(texts)
+            all_embeddings.extend(embeddings)
+            logger.debug(f"Embedded batch {i // batch_size + 1}")
+
+        # Build FAISS index (IndexFlatIP for inner product / cosine similarity on normalized vectors)
+        embeddings_array = np.array(all_embeddings, dtype=np.float32)
+
+        # Normalize for cosine similarity
+        faiss.normalize_L2(embeddings_array)
+
+        index = faiss.IndexFlatIP(self.dimensions)
+        index.add(embeddings_array)
+
+        # Create embeddings directory if needed
+        self.embeddings_path.mkdir(parents=True, exist_ok=True)
+
+        # Save FAISS index
+        faiss.write_index(index, str(self.embeddings_path / "faiss_index.bin"))
+
+        # Save metadata
+        metadata = {
+            "chunks": [asdict(chunk) for chunk in chunks],
+        }
+        with open(self.embeddings_path / "metadata.json", "w", encoding="utf-8") as f:
+            json.dump(metadata, f, indent=2)
+
+        # Save index info
+        artifact_files = set(chunk.artifact_file for chunk in chunks)
+        index_info = {
+            "total_chunks": len(chunks),
+            "total_artifacts": len(artifact_files),
+            "dimensions": self.dimensions,
+            "index_type": "IndexFlatIP",
+        }
+        with open(self.embeddings_path / "index_info.json", "w", encoding="utf-8") as f:
+            json.dump(index_info, f, indent=2)
+
+        logger.info(f"Indexed {len(chunks)} chunks from {len(artifact_files)} artifacts")
+
+        return IndexResult(
+            chunks_indexed=len(chunks),
+            artifacts_processed=len(artifact_files),
+            status="completed",
+        )
+
+
+@lru_cache()
+def get_indexer() -> ArtifactIndexer:
+    """Get cached indexer instance."""
+    return ArtifactIndexer(embedding_client=get_embedding_client())
+
+
+IndexerDependency = Annotated[ArtifactIndexer, Depends(get_indexer)]
@@ -0,0 +1,221 @@
+"""FAISS-based retrieval with adaptive selection."""
+
+import json
+import logging
+from dataclasses import dataclass
+from functools import lru_cache
+from pathlib import Path
+from typing import Annotated
+
+import faiss
+import numpy as np
+from fastapi import Depends
+
+from app.config import settings
+from app.embeddings.client import EmbeddingClient, get_embedding_client
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class RetrievedChunk:
+    """A chunk retrieved from FAISS search."""
+
+    chunk_id: str
+    content: str
+    chunk_type: str
+    artifact_file: str
+    source_file: str
+    tags: dict
+    score: float
+
+
+class Retriever:
+    """FAISS-based retriever with adaptive selection logic."""
+
+    def __init__(self, embedding_client: EmbeddingClient):
+        """Initialize the retriever.
+
+        Args:
+            embedding_client: Client for generating query embeddings
+        """
+        self.embedding_client = embedding_client
+        self.embeddings_path = Path(settings.embeddings_path)
+        self.top_k = settings.rag_top_k
+        self.threshold = settings.rag_similarity_threshold
+
+        self._index: faiss.IndexFlatIP | None = None
+        self._metadata: list[dict] | None = None
+        self._loaded = False
+
+    def load_index(self) -> bool:
+        """Load FAISS index and metadata from disk.
+
+        Returns:
+            True if successfully loaded, False otherwise
+        """
+        index_path = self.embeddings_path / "faiss_index.bin"
+        metadata_path = self.embeddings_path / "metadata.json"
+
+        if not index_path.exists() or not metadata_path.exists():
+            logger.warning("FAISS index or metadata not found. Run /index first.")
+            self._loaded = False
+            return False
+
+        try:
+            self._index = faiss.read_index(str(index_path))
+
+            with open(metadata_path, "r", encoding="utf-8") as f:
+                data = json.load(f)
+            self._metadata = data.get("chunks", [])
+
+            self._loaded = True
+            logger.info(f"Loaded FAISS index with {self._index.ntotal} vectors")
+            return True
+
+        except Exception as e:
+            logger.error(f"Failed to load FAISS index: {e}")
+            self._loaded = False
+            return False
+
+    @property
+    def is_loaded(self) -> bool:
+        """Check if index is loaded."""
+        return self._loaded and self._index is not None
+
+    @property
+    def index_size(self) -> int:
+        """Get number of vectors in index."""
+        if self._index is None:
+            return 0
+        return self._index.ntotal
+
+    def _adaptive_select(
+        self,
+        indices: np.ndarray,
+        scores: np.ndarray,
+    ) -> list[tuple[int, float]]:
+        """Apply adaptive selection logic.
+
+        - Always include top 2 chunks (regardless of score)
+        - For chunks 3-5: apply threshold
+        - Limit to self.top_k chunks total
+
+        Args:
+            indices: FAISS result indices
+            scores: FAISS result scores
+
+        Returns:
+            List of (index, score) tuples for selected chunks
+        """
+        selected = []
+
+        for i, (idx, score) in enumerate(zip(indices, scores)):
+            if idx == -1:  # FAISS returns -1 for no match
+                continue
+
+            # Always take top 2
+            if i < 2:
+                selected.append((int(idx), float(score)))
+            # Apply threshold for remaining
+            elif score >= self.threshold and len(selected) < self.top_k:
+                selected.append((int(idx), float(score)))
+
+        return selected
+
+    def _apply_diversity_filter(
+        self,
+        candidates: list[tuple[int, float]],
+        max_per_artifact: int = 2,
+    ) -> list[tuple[int, float]]:
+        """Limit chunks per artifact for diversity.
+
+        Args:
+            candidates: List of (index, score) tuples
+            max_per_artifact: Maximum chunks from same artifact
+
+        Returns:
+            Filtered list of (index, score) tuples
+        """
+        artifact_counts: dict[str, int] = {}
+        filtered = []
+
+        for idx, score in candidates:
+            chunk = self._metadata[idx]
+            artifact = chunk["artifact_file"]
+
+            if artifact_counts.get(artifact, 0) < max_per_artifact:
+                filtered.append((idx, score))
+                artifact_counts[artifact] = artifact_counts.get(artifact, 0) + 1
+
+        return filtered
+
+    async def search(self, query: str) -> list[RetrievedChunk]:
+        """Search for relevant chunks.
+
+        Args:
+            query: User's question
+
+        Returns:
+            List of retrieved chunks with relevance scores
+        """
+        if not self.is_loaded:
+            if not self.load_index():
+                return []
+
+        # Generate query embedding
+        query_embedding = await self.embedding_client.embed(query)
+        query_vector = np.array([query_embedding], dtype=np.float32)
+
+        # Normalize for cosine similarity
+        faiss.normalize_L2(query_vector)
+
+        # Search FAISS (get more candidates than needed for filtering)
+        k_search = min(8, self._index.ntotal)
+        scores, indices = self._index.search(query_vector, k_search)
+
+        # Apply adaptive selection
+        selected = self._adaptive_select(indices[0], scores[0])
+
+        # Apply diversity filter
+        filtered = self._apply_diversity_filter(selected)
+
+        # Build result chunks
+        results = []
+        for idx, score in filtered:
+            chunk_data = self._metadata[idx]
+            results.append(RetrievedChunk(
+                chunk_id=chunk_data["chunk_id"],
+                content=chunk_data["content"],
+                chunk_type=chunk_data["chunk_type"],
+                artifact_file=chunk_data["artifact_file"],
+                source_file=chunk_data["source_file"],
+                tags=chunk_data.get("tags", {}),
+                score=score,
+            ))
+
+        logger.debug(f"Retrieved {len(results)} chunks for query")
+        return results
+
+
+# Singleton retriever instance
+_retriever: Retriever | None = None
+
+
+def get_retriever() -> Retriever:
+    """Get singleton retriever instance (lazily initialized)."""
+    global _retriever
+    if _retriever is None:
+        _retriever = Retriever(embedding_client=get_embedding_client())
+        # Attempt to load index at startup
+        _retriever.load_index()
+    return _retriever
+
+
+def reset_retriever() -> None:
+    """Reset the singleton retriever (for reloading after re-indexing)."""
+    global _retriever
+    _retriever = None
+
+
+RetrieverDependency = Annotated[Retriever, Depends(get_retriever)]
@@ -0,0 +1,9 @@
+"""Intent classification module."""
+
+from app.intent.classifier import IntentClassifier, get_intent_classifier, IntentClassifierDependency
+
+__all__ = [
+    "IntentClassifier",
+    "get_intent_classifier",
+    "IntentClassifierDependency",
+]
@@ -0,0 +1,96 @@
+"""Lightweight intent classification using gpt-4o-mini."""
+
+import logging
+from functools import lru_cache
+from typing import Annotated, Literal
+
+from fastapi import Depends
+from openai import AsyncOpenAI
+
+from app.config import settings
+from app.memory.conversation import Message
+
+logger = logging.getLogger(__name__)
+
+Intent = Literal["codebase", "general", "clarification"]
+
+INTENT_PROMPT = """Classify this user message into one category:
+- "codebase": Questions about trading system code, architecture, files, methods, execution, strategies, exchanges, risk management, order handling, or technical implementation
+- "general": Greetings, meta-questions, off-topic ("How are you?", "What can you do?", "Hello")
+- "clarification": Follow-ups that rely on conversation context, not new retrieval ("Tell me more", "What did you mean?", "Can you explain that?")
+
+IMPORTANT: If the user is asking about specific code, files, classes, methods, or system behavior, classify as "codebase".
+
+Respond with ONLY the category name, nothing else."""
+
+
+class IntentClassifier:
+    """Lightweight intent classifier using gpt-4o-mini."""
+
+    def __init__(self, api_key: str):
+        """Initialize the classifier.
+
+        Args:
+            api_key: OpenAI API key
+        """
+        self.client = AsyncOpenAI(api_key=api_key)
+        self.model = "gpt-4o-mini"
+
+    async def classify(
+        self,
+        message: str,
+        history: list[Message] | None = None,
+    ) -> Intent:
+        """Classify user message intent.
+
+        Args:
+            message: User's message
+            history: Optional conversation history for context
+
+        Returns:
+            Classified intent: "codebase", "general", or "clarification"
+        """
+        # Build context from history (last 2 turns)
+        context = ""
+        if history and len(history) >= 2:
+            recent = history[-4:]  # Last 2 exchanges
+            context = "Recent conversation:\n"
+            for msg in recent:
+                role = "User" if msg.role == "user" else "Assistant"
+                context += f"{role}: {msg.content[:100]}...\n" if len(msg.content) > 100 else f"{role}: {msg.content}\n"
+            context += "\n"
+
+        try:
+            response = await self.client.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {"role": "system", "content": INTENT_PROMPT},
+                    {"role": "user", "content": f"{context}Current message: {message}"},
+                ],
+                max_tokens=10,
+                temperature=0,
+            )
+
+            raw_intent = response.choices[0].message.content.strip().lower()
+
+            # Validate intent
+            if raw_intent in ("codebase", "general", "clarification"):
+                logger.debug(f"Classified intent: {raw_intent}")
+                return raw_intent
+
+            # Default to codebase for ambiguous cases (safer for RAG)
+            logger.warning(f"Unexpected intent response: {raw_intent}, defaulting to codebase")
+            return "codebase"
+
+        except Exception as e:
+            logger.warning(f"Intent classification failed: {e}, defaulting to codebase")
+            return "codebase"
+
+
+@lru_cache()
+def get_intent_classifier() -> IntentClassifier:
+    """Get cached intent classifier instance."""
+    return IntentClassifier(api_key=settings.openai_api_key)
+
+
+IntentClassifierDependency = Annotated[IntentClassifier, Depends(get_intent_classifier)]
app/main.py
@@ -1,17 +1,24 @@
 import logging
 import uuid
 
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, BackgroundTasks
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 
 from app.auth import ServiceAuthDependency
 from app.config import settings, MAX_MESSAGE_LENGTH
+from app.embeddings import RetrieverDependency, IndexerDependency, get_retriever, reset_retriever
+from app.embeddings.retriever import RetrievedChunk
+from app.intent import IntentClassifierDependency
 from app.llm import AdapterDependency, LLMError, llm_exception_to_http
+from app.memory import ConversationMemoryDependency
+from app.prompts import build_rag_prompt
 from app.schemas import (
     ChatRequest,
     ChatResponse,
     HealthResponse,
+    IndexResponse,
+    SourceReference,
     StreamChunkEvent,
     StreamDoneEvent,
     StreamErrorEvent,
@@ -43,21 +50,101 @@ app.add_middleware(
 
 @app.get("/health", response_model=HealthResponse)
 async def health_check() -> HealthResponse:
-    """Health check endpoint."""
-    return HealthResponse(status="ok")
+    """Health check endpoint with FAISS status."""
+    retriever = get_retriever()
+
+    return HealthResponse(
+        status="ok",
+        faiss_loaded=retriever.is_loaded,
+        index_size=retriever.index_size,
+    )
+
+
+@app.get("/debug/search")
+async def debug_search(
+    q: str,
+    retriever: RetrieverDependency,
+) -> dict:
+    """Debug endpoint to test retrieval directly."""
+    chunks = await retriever.search(q)
+    return {
+        "query": q,
+        "chunks_found": len(chunks),
+        "chunks": [
+            {
+                "source_file": c.source_file,
+                "chunk_type": c.chunk_type,
+                "score": round(c.score, 4),
+                "content_preview": c.content[:200] + "..." if len(c.content) > 200 else c.content,
+            }
+            for c in chunks
+        ],
+    }
+
+
+@app.post("/index", response_model=IndexResponse)
+async def reindex_artifacts(
+    background_tasks: BackgroundTasks,
+    indexer: IndexerDependency,
+    _auth: ServiceAuthDependency,
+) -> IndexResponse:
+    """Trigger re-indexing of YAML artifacts.
+
+    Builds FAISS index from all YAML files in the artifacts directory.
+    """
+    logger.info("Starting artifact indexing...")
+
+    try:
+        result = await indexer.build_index()
+
+        # Reset retriever to reload new index
+        reset_retriever()
+
+        logger.info(
+            f"Indexing completed: {result.chunks_indexed} chunks from {result.artifacts_processed} artifacts"
+        )
+
+        return IndexResponse(
+            status=result.status,
+            chunks_indexed=result.chunks_indexed,
+            artifacts_processed=result.artifacts_processed,
+        )
+
+    except Exception as e:
+        logger.error(f"Indexing failed: {e}")
+        raise HTTPException(status_code=500, detail=f"Indexing failed: {str(e)}")
+
+
+def _chunks_to_sources(chunks: list[RetrievedChunk]) -> list[SourceReference]:
+    """Convert retrieved chunks to source references."""
+    return [
+        SourceReference(
+            artifact_file=chunk.artifact_file,
+            source_file=chunk.source_file,
+            chunk_type=chunk.chunk_type,
+            relevance_score=chunk.score,
+        )
+        for chunk in chunks
+    ]
+
+
 @app.post("/chat", response_model=ChatResponse)
 async def chat(
     request: ChatRequest,
     adapter: AdapterDependency,
+    retriever: RetrieverDependency,
+    memory: ConversationMemoryDependency,
+    classifier: IntentClassifierDependency,
     _auth: ServiceAuthDependency,
 ) -> ChatResponse:
-    """Process a chat message through the LLM adapter.
+    """Process a chat message through the RAG pipeline.
 
     - Validates message length
     - Generates conversation_id if not provided
-    - Routes to appropriate LLM adapter based on LLM_MODE
+    - Classifies intent (codebase/general/clarification)
+    - Retrieves relevant context from FAISS (for codebase intent)
+    - Builds RAG prompt and generates response
+    - Stores conversation turn in memory
     """
     # Validate message length
     if len(request.message) > MAX_MESSAGE_LENGTH:
@@ -70,19 +157,42 @@ async def chat(
     # Generate or use provided conversation_id
     conversation_id = request.conversation_id or str(uuid.uuid4())
 
-    # Log request metadata (not content)
+    # Get conversation history
+    history = await memory.get_history(conversation_id)
+
+    # Classify intent
+    intent = await classifier.classify(request.message, history)
+
+    # Log request metadata
     logger.info(
         "Chat request received",
         extra={
             "conversation_id": conversation_id,
             "message_length": len(request.message),
             "mode": settings.llm_mode,
+            "intent": intent,
         },
     )
 
+    # Retrieve context for codebase questions
+    chunks: list[RetrievedChunk] = []
+    if intent == "codebase":
+        chunks = await retriever.search(request.message)
+        logger.debug(f"Retrieved {len(chunks)} chunks for codebase question")
+
+    # Build RAG prompt
+    system_prompt, user_content = build_rag_prompt(
+        user_message=request.message,
+        intent=intent,
+        chunks=chunks,
+        history=history,
+    )
+
     # Generate response with exception handling
     try:
-        response_text = await adapter.generate(conversation_id, request.message)
+        # For RAG, we pass the full constructed prompt
+        full_prompt = f"{system_prompt}\n\n{user_content}"
+        response_text = await adapter.generate(conversation_id, full_prompt)
     except LLMError as e:
         logger.error(
             "LLM generation failed",
@@ -94,6 +204,15 @@ async def chat(
         )
         raise llm_exception_to_http(e)
 
+    # Store conversation turn
+    sources = _chunks_to_sources(chunks)
+    await memory.store_turn(
+        conversation_id=conversation_id,
+        user_message=request.message,
+        assistant_message=response_text,
+        sources=[s.model_dump() for s in sources] if sources else None,
+    )
+
     # Log response metadata
     logger.info(
         "Chat response generated",
@@ -101,6 +220,8 @@ async def chat(
             "conversation_id": conversation_id,
             "response_length": len(response_text),
             "mode": settings.llm_mode,
+            "intent": intent,
+            "sources_count": len(sources),
         },
     )
@@ -108,7 +229,8 @@ async def chat(
         conversation_id=conversation_id,
         response=response_text,
         mode=settings.llm_mode,
-        sources=[],
+        intent=intent,
+        sources=sources,
     )
 
 
@@ -116,14 +238,18 @@ async def chat(
 async def chat_stream(
     request: ChatRequest,
     adapter: AdapterDependency,
+    retriever: RetrieverDependency,
+    memory: ConversationMemoryDependency,
+    classifier: IntentClassifierDependency,
     _auth: ServiceAuthDependency,
 ) -> StreamingResponse:
-    """Stream a chat response through the LLM adapter using Server-Sent Events.
+    """Stream a chat response through the RAG pipeline using Server-Sent Events.
 
     - Validates message length
     - Generates conversation_id if not provided
-    - Routes to appropriate LLM adapter based on LLM_MODE
-    - Returns streaming response with SSE format
+    - Classifies intent and retrieves context
+    - Streams response with SSE format
+    - Stores conversation turn after completion
     """
     # Validate message length
     if len(request.message) > MAX_MESSAGE_LENGTH:
@@ -136,26 +262,64 @@ async def chat_stream(
     # Generate or use provided conversation_id
     conversation_id = request.conversation_id or str(uuid.uuid4())
 
-    # Log request metadata (not content)
+    # Get conversation history
+    history = await memory.get_history(conversation_id)
+
+    # Classify intent
+    intent = await classifier.classify(request.message, history)
+
+    # Retrieve context for codebase questions
+    chunks: list[RetrievedChunk] = []
+    if intent == "codebase":
+        chunks = await retriever.search(request.message)
+
+    # Build RAG prompt
+    system_prompt, user_content = build_rag_prompt(
+        user_message=request.message,
+        intent=intent,
+        chunks=chunks,
+        history=history,
+    )
+
+    sources = _chunks_to_sources(chunks)
+
+    # Log request metadata
     logger.info(
         "Chat stream request received",
         extra={
             "conversation_id": conversation_id,
             "message_length": len(request.message),
             "mode": settings.llm_mode,
+            "intent": intent,
         },
     )
 
     async def stream_response():
         """Async generator that yields SSE-formatted events."""
+        full_response = []
+
         try:
-            async for chunk in adapter.generate_stream(conversation_id, request.message):
+            full_prompt = f"{system_prompt}\n\n{user_content}"
+            async for chunk in adapter.generate_stream(conversation_id, full_prompt):
+                full_response.append(chunk)
                 event = StreamChunkEvent(content=chunk, conversation_id=conversation_id)
                 yield f"data: {event.model_dump_json()}\n\n"
 
-            # Send completion event
+            # Store conversation turn
+            response_text = "".join(full_response)
+            await memory.store_turn(
+                conversation_id=conversation_id,
+                user_message=request.message,
+                assistant_message=response_text,
+                sources=[s.model_dump() for s in sources] if sources else None,
+            )
+
+            # Send completion event with sources
             done_event = StreamDoneEvent(
-                conversation_id=conversation_id, mode=settings.llm_mode
+                conversation_id=conversation_id,
+                mode=settings.llm_mode,
+                intent=intent,
+                sources=sources,
             )
             yield f"data: {done_event.model_dump_json()}\n\n"
@@ -164,6 +328,8 @@ async def chat_stream(
                 extra={
                     "conversation_id": conversation_id,
                     "mode": settings.llm_mode,
+                    "intent": intent,
+                    "sources_count": len(sources),
                 },
             )
@@ -0,0 +1,9 @@
+"""Memory module for conversation history management."""
+
+from app.memory.conversation import ConversationMemory, get_conversation_memory, ConversationMemoryDependency
+
+__all__ = [
+    "ConversationMemory",
+    "get_conversation_memory",
+    "ConversationMemoryDependency",
+]
@@ -0,0 +1,122 @@
+"""Conversation history management with in-memory storage."""
+
+import logging
+import time
+from dataclasses import dataclass, asdict, field
+from typing import Annotated
+
+from fastapi import Depends
+
+from app.config import settings
+
+logger = logging.getLogger(__name__)
+
+MAX_HISTORY_MESSAGES = 20
+
+
+@dataclass
+class Message:
+    """A single message in conversation history."""
+
+    role: str  # "user" or "assistant"
+    content: str
+    sources: list[dict] | None = None
+
+
+@dataclass
+class ConversationData:
+    """Container for conversation messages with timestamp for TTL."""
+
+    messages: list[Message] = field(default_factory=list)
+    last_updated: float = field(default_factory=time.time)
+
+
+# Module-level storage for conversations
+_conversations: dict[str, ConversationData] = {}
+
+
+class ConversationMemory:
+    """Manages conversation history in memory."""
+
+    def __init__(self, ttl: int):
+        """Initialize conversation memory.
+
+        Args:
+            ttl: Time-to-live in seconds for conversations
+        """
+        self.ttl = ttl
+
+    async def get_history(self, conversation_id: str) -> list[Message]:
+        """Get conversation history.
+
+        Args:
+            conversation_id: Conversation identifier
+
+        Returns:
+            List of messages in chronological order, or empty list if expired/not found
+        """
+        data = _conversations.get(conversation_id)
+        if data is None:
+            return []
+
+        # Check if expired
+        if time.time() - data.last_updated > self.ttl:
+            del _conversations[conversation_id]
+            return []
+
+        return data.messages
+
+    async def store_turn(
+        self,
+        conversation_id: str,
+        user_message: str,
+        assistant_message: str,
+        sources: list[dict] | None = None,
+    ) -> None:
+        """Store a conversation turn (user message + assistant response).
+
+        Args:
+            conversation_id: Conversation identifier
+            user_message: User's message
+            assistant_message: Assistant's response
+            sources: Optional source references used in response
+        """
+        # Get existing history (checks TTL)
+        history = await self.get_history(conversation_id)
+
+        # Add new messages
+        history.append(Message(role="user", content=user_message))
+        history.append(Message(role="assistant", content=assistant_message, sources=sources))
+
+        # Trim to max size (keep most recent)
+        if len(history) > MAX_HISTORY_MESSAGES:
+            history = history[-MAX_HISTORY_MESSAGES:]
+
+        # Store with updated timestamp
+        _conversations[conversation_id] = ConversationData(
+            messages=history,
+            last_updated=time.time(),
+        )
+
+        logger.debug(f"Stored conversation turn for {conversation_id}")
+
+    async def clear(self, conversation_id: str) -> bool:
+        """Clear conversation history.
+
+        Args:
+            conversation_id: Conversation identifier
+
+        Returns:
+            True if cleared successfully
+        """
+        if conversation_id in _conversations:
+            del _conversations[conversation_id]
+        return True
+
+
+def get_conversation_memory() -> ConversationMemory:
+    """Get conversation memory instance."""
+    return ConversationMemory(ttl=settings.conversation_ttl)
+
+
+ConversationMemoryDependency = Annotated[ConversationMemory, Depends(get_conversation_memory)]
@@ -0,0 +1,136 @@
+"""System prompts for RAG-based codebase Q&A."""
+
+from app.embeddings.retriever import RetrievedChunk
+from app.memory.conversation import Message
+
+# For codebase questions WITH retrieved context
+CODEBASE_SYSTEM_PROMPT = """You answer questions about the Tyndale trading system using ONLY the provided YAML artifacts.
+
+HARD CONSTRAINTS:
+- Do NOT assume access to source code
+- Do NOT invent implementation details not in the artifacts
+- Do NOT speculate about code mechanics beyond what artifacts describe
+- If artifacts do not contain enough information, say so explicitly
+
+RESPONSE STYLE:
+- Prefer architectural and behavioral explanations over mechanics
+- Reference source files by path (e.g., ./trader.py)
+- Explain trading concepts for developers without finance background
+- Keep responses focused and concise"""
+
+# For general questions (no RAG context)
+GENERAL_SYSTEM_PROMPT = """You are an assistant for the Tyndale trading system documentation.
+You can answer general questions, but for specific codebase questions, you need artifact context.
+If the user asks about specific code without context, ask them to rephrase or be more specific."""
+
+# For clarification/follow-ups (uses conversation history)
+CLARIFICATION_SYSTEM_PROMPT = """You are continuing a conversation about the Tyndale trading system.
+Use the conversation history to answer follow-up questions.
+If you need to look up new information, ask the user to rephrase as a standalone question."""
+
+
+def select_system_prompt(intent: str, has_context: bool) -> str:
+    """Select appropriate system prompt based on intent and context.
+
+    Args:
+        intent: Classified intent (codebase, general, clarification)
+        has_context: Whether RAG context was retrieved
+
+    Returns:
+        System prompt string
+    """
+    if intent == "codebase" and has_context:
+        return CODEBASE_SYSTEM_PROMPT
+    elif intent == "codebase" and not has_context:
+        return CODEBASE_SYSTEM_PROMPT + "\n\nNOTE: No relevant artifacts were found for this question. Acknowledge this limitation in your response."
+    elif intent == "clarification":
+        return CLARIFICATION_SYSTEM_PROMPT
+    else:
+        return GENERAL_SYSTEM_PROMPT
+
+
+def format_context(chunks: list[RetrievedChunk]) -> str:
+    """Format retrieved chunks as context for the LLM.
+
+    Args:
+        chunks: List of retrieved chunks
+
+    Returns:
+        Formatted context string
+    """
+    if not chunks:
+        return ""
+
+    context_parts = ["## Retrieved Artifact Context\n"]
+
+    for i, chunk in enumerate(chunks, 1):
+        context_parts.append(f"### Source {i}: {chunk.source_file} ({chunk.chunk_type})")
+        context_parts.append(chunk.content)
+        context_parts.append("")  # Empty line separator
+
+    return "\n".join(context_parts)
+
+
+def format_history(history: list[Message], max_messages: int = 10) -> str:
+    """Format conversation history for the LLM.
+
+    Args:
+        history: List of conversation messages
+        max_messages: Maximum messages to include
+
+    Returns:
+        Formatted history string
+    """
+    if not history:
+        return ""
+
+    # Take most recent messages
+    recent = history[-max_messages:] if len(history) > max_messages else history
+
+    history_parts = ["## Conversation History\n"]
+
+    for msg in recent:
+        role = "User" if msg.role == "user" else "Assistant"
+        history_parts.append(f"**{role}**: {msg.content}")
+        history_parts.append("")
+
+    return "\n".join(history_parts)
+
+
+def build_rag_prompt(
+    user_message: str,
+    intent: str,
+    chunks: list[RetrievedChunk],
+    history: list[Message],
+) -> tuple[str, str]:
+    """Build complete RAG prompt with system message and user content.
+
+    Args:
+        user_message: Current user message
+        intent: Classified intent
+        chunks: Retrieved context chunks
+        history: Conversation history
+
+    Returns:
+        Tuple of (system_prompt, user_content)
+    """
+    # Select system prompt
+    system_prompt = select_system_prompt(intent, bool(chunks))
+
+    # Build user content
+    parts = []
+
+    # Add context if available
+    if chunks:
+        parts.append(format_context(chunks))
+
+    # Add history for clarification intent
+    if intent == "clarification" and history:
+        parts.append(format_history(history))
+
+    # Add current question
+    parts.append(f"## Current Question\n{user_message}")
+
+    user_content = "\n\n".join(parts)
+
+    return system_prompt, user_content
@@ -3,6 +3,15 @@ from typing import Literal
 from pydantic import BaseModel, Field
 
 
+class SourceReference(BaseModel):
+    """Reference to a source artifact chunk used in RAG response."""
+
+    artifact_file: str = Field(..., description="Path to the YAML artifact file")
+    source_file: str = Field(..., description="Path to the source code file")
+    chunk_type: str = Field(..., description="Type of chunk (factual_summary, interpretive_summary, method, invariants)")
+    relevance_score: float = Field(..., description="Similarity score from FAISS search")
+
+
 class ChatRequest(BaseModel):
     """Request model for the /chat endpoint."""
 
@@ -18,13 +27,18 @@ class ChatResponse(BaseModel):
     conversation_id: str = Field(..., description="Conversation ID (generated if not provided)")
     response: str = Field(..., description="The LLM's response")
     mode: Literal["local", "remote", "openai", "asksage"] = Field(..., description="Which adapter was used")
-    sources: list = Field(default_factory=list, description="Source references (empty for now)")
+    intent: Literal["codebase", "general", "clarification"] = Field(
+        default="general", description="Classified intent of the user message"
+    )
+    sources: list[SourceReference] = Field(default_factory=list, description="Source artifact references used in response")
 
 
 class HealthResponse(BaseModel):
     """Response model for the /health endpoint."""
 
     status: str = Field(default="ok")
+    faiss_loaded: bool = Field(default=False, description="Whether FAISS index is loaded")
+    index_size: int = Field(default=0, description="Number of chunks in FAISS index")
 
 
 class ErrorResponse(BaseModel):
 
@@ -50,6 +64,18 @@ class StreamDoneEvent(BaseModel):
     type: Literal["done"] = "done"
     conversation_id: str = Field(..., description="Conversation ID")
     mode: Literal["local", "remote", "openai", "asksage"] = Field(..., description="Which adapter was used")
+    intent: Literal["codebase", "general", "clarification"] = Field(
+        default="general", description="Classified intent of the user message"
+    )
+    sources: list[SourceReference] = Field(default_factory=list, description="Source artifact references used in response")
+
+
+class IndexResponse(BaseModel):
+    """Response model for the /index endpoint."""
+
+    status: str = Field(..., description="Indexing status")
+    chunks_indexed: int = Field(default=0, description="Number of chunks indexed")
+    artifacts_processed: int = Field(default=0, description="Number of YAML files processed")
 
 
 class StreamErrorEvent(BaseModel):
@@ -32,3 +32,4 @@ urllib3==2.6.3
 uvicorn==0.40.0
 watchfiles==1.1.1
 websockets==16.0
+faiss-cpu>=1.12.0