feat: replace Redis with in-memory conversation storage

- Remove Redis dependency and redis_client.py
- Implement ConversationMemory with module-level dictionary
- Add TTL support via timestamp checking
- Remove redis_connected from health endpoint
- Add embeddings, intent classification, and RAG prompt modules

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Author: Danny
Date:   2026-01-30 10:34:47 -06:00
Parent: 72778b65b5
Commit: b0211b944d

13 changed files with 1168 additions and 16 deletions
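For orientation, a minimal sketch of how callers use the new in-memory store in place of the old Redis client (illustrative only, not part of the commit; assumes the backend settings are importable and configured):

import asyncio

from app.memory.conversation import get_conversation_memory


async def demo() -> None:
    memory = get_conversation_memory()  # TTL comes from settings.conversation_ttl
    await memory.store_turn(
        conversation_id="demo-123",
        user_message="What does trader.py do?",
        assistant_message="It coordinates order execution.",
    )
    history = await memory.get_history("demo-123")  # [] again once the TTL has elapsed
    print([m.role for m in history])  # ['user', 'assistant']


asyncio.run(demo())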

app/config.py

@@ -22,6 +22,18 @@ class Settings(BaseSettings):
     # CORS configuration
     cors_origins: str = "http://localhost:3000"
 
+    # Conversation memory configuration
+    conversation_ttl: int = 86400  # 24 hours
+
+    # Embedding configuration
+    embedding_model: str = "text-embedding-3-small"
+
+    # RAG configuration
+    rag_top_k: int = 5
+    rag_similarity_threshold: float = 0.70
+    artifacts_path: str = "artifacts"
+    embeddings_path: str = "embeddings"
+
     # Authentication settings
     auth_enabled: bool = False  # Set to True in production
     auth_audience: str = ""  # Backend Cloud Run URL (e.g., https://backend-xxx.run.app)
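For reference, the new settings can be overridden through environment variables in the usual pydantic-settings way; the variable names below assume the default field-name mapping with no env prefix (illustrative values):

# .env (hypothetical)
CONVERSATION_TTL=86400
EMBEDDING_MODEL=text-embedding-3-small
RAG_TOP_K=5
RAG_SIMILARITY_THRESHOLD=0.70
ARTIFACTS_PATH=artifacts
EMBEDDINGS_PATH=embeddings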

app/embeddings/__init__.py

@@ -0,0 +1,18 @@
"""Embeddings module for YAML artifact indexing and retrieval."""
from app.embeddings.client import EmbeddingClient, get_embedding_client, EmbeddingClientDependency
from app.embeddings.indexer import ArtifactIndexer, get_indexer, IndexerDependency
from app.embeddings.retriever import Retriever, get_retriever, RetrieverDependency, reset_retriever
__all__ = [
"EmbeddingClient",
"get_embedding_client",
"EmbeddingClientDependency",
"ArtifactIndexer",
"get_indexer",
"IndexerDependency",
"Retriever",
"get_retriever",
"RetrieverDependency",
"reset_retriever",
]

app/embeddings/client.py

@@ -0,0 +1,109 @@
"""OpenAI embedding client wrapper."""
import logging
from functools import lru_cache
from typing import Annotated
from fastapi import Depends
from openai import AsyncOpenAI, AuthenticationError, RateLimitError, APIConnectionError, APIError
from app.config import settings
logger = logging.getLogger(__name__)
class EmbeddingError(Exception):
"""Base exception for embedding operations."""
def __init__(self, message: str, status_code: int = 500):
self.message = message
self.status_code = status_code
super().__init__(message)
class EmbeddingClient:
"""Async wrapper for OpenAI embeddings API."""
def __init__(self, api_key: str, model: str = "text-embedding-3-small"):
"""Initialize the embedding client.
Args:
api_key: OpenAI API key
model: Embedding model identifier
"""
self.client = AsyncOpenAI(api_key=api_key)
self.model = model
self.dimensions = 1536 # text-embedding-3-small dimension
async def embed(self, text: str) -> list[float]:
"""Generate embedding for a single text.
Args:
text: Text to embed
Returns:
Embedding vector (1536 dimensions)
Raises:
EmbeddingError: If embedding generation fails
"""
try:
response = await self.client.embeddings.create(
model=self.model,
input=text,
)
return response.data[0].embedding
except AuthenticationError as e:
raise EmbeddingError(f"OpenAI authentication failed: {e.message}", 401)
except RateLimitError as e:
raise EmbeddingError(f"OpenAI rate limit exceeded: {e.message}", 429)
except APIConnectionError as e:
raise EmbeddingError(f"Could not connect to OpenAI: {str(e)}", 503)
except APIError as e:
raise EmbeddingError(f"OpenAI API error: {e.message}", e.status_code or 500)
async def embed_batch(self, texts: list[str]) -> list[list[float]]:
"""Generate embeddings for multiple texts.
Args:
texts: List of texts to embed
Returns:
List of embedding vectors
Raises:
EmbeddingError: If embedding generation fails
"""
if not texts:
return []
try:
response = await self.client.embeddings.create(
model=self.model,
input=texts,
)
# Sort by index to ensure correct ordering
sorted_embeddings = sorted(response.data, key=lambda x: x.index)
return [item.embedding for item in sorted_embeddings]
except AuthenticationError as e:
raise EmbeddingError(f"OpenAI authentication failed: {e.message}", 401)
except RateLimitError as e:
raise EmbeddingError(f"OpenAI rate limit exceeded: {e.message}", 429)
except APIConnectionError as e:
raise EmbeddingError(f"Could not connect to OpenAI: {str(e)}", 503)
except APIError as e:
raise EmbeddingError(f"OpenAI API error: {e.message}", e.status_code or 500)
@lru_cache()
def get_embedding_client() -> EmbeddingClient:
"""Get cached embedding client instance."""
return EmbeddingClient(
api_key=settings.openai_api_key,
model=settings.embedding_model,
)
EmbeddingClientDependency = Annotated[EmbeddingClient, Depends(get_embedding_client)]
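A short usage sketch for the client (illustrative; assumes settings.openai_api_key is configured):

import asyncio

from app.embeddings.client import get_embedding_client


async def demo() -> None:
    client = get_embedding_client()
    vector = await client.embed("How does order routing work?")
    print(len(vector))  # 1536 for text-embedding-3-small
    vectors = await client.embed_batch(["risk checks", "position sizing"])
    print(len(vectors))  # one vector per input, returned in input order


asyncio.run(demo())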

app/embeddings/indexer.py

@@ -0,0 +1,227 @@
"""YAML artifact indexer for building FAISS index."""
import json
import logging
from dataclasses import dataclass, asdict
from functools import lru_cache
from pathlib import Path
from typing import Annotated
import faiss
import numpy as np
import yaml
from fastapi import Depends
from app.config import settings
from app.embeddings.client import EmbeddingClient, get_embedding_client, EmbeddingError
logger = logging.getLogger(__name__)
@dataclass
class Chunk:
"""Represents a text chunk from a YAML artifact."""
chunk_id: str
content: str
chunk_type: str # factual_summary, interpretive_summary, method, invariants
artifact_file: str
source_file: str
tags: dict
@dataclass
class IndexResult:
"""Result of indexing operation."""
chunks_indexed: int
artifacts_processed: int
status: str
class ArtifactIndexer:
"""Parses YAML artifacts and builds FAISS index."""
def __init__(self, embedding_client: EmbeddingClient):
"""Initialize the indexer.
Args:
embedding_client: Client for generating embeddings
"""
self.embedding_client = embedding_client
self.artifacts_path = Path(settings.artifacts_path)
self.embeddings_path = Path(settings.embeddings_path)
self.dimensions = 1536
def _parse_yaml_to_chunks(self, yaml_path: Path) -> list[Chunk]:
"""Parse a YAML artifact file into chunks.
Args:
yaml_path: Path to the YAML file
Returns:
List of chunks extracted from the file
"""
chunks = []
artifact_file = str(yaml_path.relative_to(self.artifacts_path))
try:
with open(yaml_path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
except Exception as e:
logger.warning(f"Failed to parse {yaml_path}: {e}")
return chunks
if not data:
return chunks
source_file = data.get("source_file", "unknown")
tags = data.get("tags", {})
# Chunk 1: Factual summary
if factual := data.get("factual_summary"):
chunks.append(Chunk(
chunk_id=f"{artifact_file}::factual_summary",
content=f"[{source_file}] Factual summary: {factual.strip()}",
chunk_type="factual_summary",
artifact_file=artifact_file,
source_file=source_file,
tags=tags,
))
# Chunk 2: Interpretive summary
if interpretive := data.get("interpretive_summary"):
chunks.append(Chunk(
chunk_id=f"{artifact_file}::interpretive_summary",
content=f"[{source_file}] Interpretive summary: {interpretive.strip()}",
chunk_type="interpretive_summary",
artifact_file=artifact_file,
source_file=source_file,
tags=tags,
))
# Chunk per method
if methods := data.get("methods"):
for method_sig, method_data in methods.items():
description = method_data.get("description", "") if isinstance(method_data, dict) else method_data
if description:
chunks.append(Chunk(
chunk_id=f"{artifact_file}::method::{method_sig}",
content=f"[{source_file}] Method {method_sig}: {description.strip()}",
chunk_type="method",
artifact_file=artifact_file,
source_file=source_file,
tags=tags,
))
# Chunk for invariants (combined)
if invariants := data.get("invariants"):
invariants_text = " ".join(f"- {inv}" for inv in invariants)
chunks.append(Chunk(
chunk_id=f"{artifact_file}::invariants",
content=f"[{source_file}] Invariants: {invariants_text}",
chunk_type="invariants",
artifact_file=artifact_file,
source_file=source_file,
tags=tags,
))
return chunks
def _collect_all_chunks(self) -> list[Chunk]:
"""Collect chunks from all YAML artifacts.
Returns:
List of all chunks from all artifacts
"""
all_chunks = []
for yaml_path in self.artifacts_path.rglob("*.yaml"):
chunks = self._parse_yaml_to_chunks(yaml_path)
all_chunks.extend(chunks)
logger.debug(f"Parsed {len(chunks)} chunks from {yaml_path}")
return all_chunks
async def build_index(self) -> IndexResult:
"""Build FAISS index from all YAML artifacts.
Returns:
IndexResult with statistics
Raises:
EmbeddingError: If embedding generation fails
"""
# Collect all chunks
chunks = self._collect_all_chunks()
if not chunks:
logger.warning("No chunks found in artifacts")
return IndexResult(
chunks_indexed=0,
artifacts_processed=0,
status="no_artifacts",
)
logger.info(f"Generating embeddings for {len(chunks)} chunks...")
# Generate embeddings in batches
batch_size = 100
all_embeddings = []
for i in range(0, len(chunks), batch_size):
batch = chunks[i:i + batch_size]
texts = [chunk.content for chunk in batch]
embeddings = await self.embedding_client.embed_batch(texts)
all_embeddings.extend(embeddings)
logger.debug(f"Embedded batch {i // batch_size + 1}")
# Build FAISS index (IndexFlatIP for inner product / cosine similarity on normalized vectors)
embeddings_array = np.array(all_embeddings, dtype=np.float32)
# Normalize for cosine similarity
faiss.normalize_L2(embeddings_array)
index = faiss.IndexFlatIP(self.dimensions)
index.add(embeddings_array)
# Create embeddings directory if needed
self.embeddings_path.mkdir(parents=True, exist_ok=True)
# Save FAISS index
faiss.write_index(index, str(self.embeddings_path / "faiss_index.bin"))
# Save metadata
metadata = {
"chunks": [asdict(chunk) for chunk in chunks],
}
with open(self.embeddings_path / "metadata.json", "w", encoding="utf-8") as f:
json.dump(metadata, f, indent=2)
# Save index info
artifact_files = set(chunk.artifact_file for chunk in chunks)
index_info = {
"total_chunks": len(chunks),
"total_artifacts": len(artifact_files),
"dimensions": self.dimensions,
"index_type": "IndexFlatIP",
}
with open(self.embeddings_path / "index_info.json", "w", encoding="utf-8") as f:
json.dump(index_info, f, indent=2)
logger.info(f"Indexed {len(chunks)} chunks from {len(artifact_files)} artifacts")
return IndexResult(
chunks_indexed=len(chunks),
artifacts_processed=len(artifact_files),
status="completed",
)
@lru_cache()
def get_indexer() -> ArtifactIndexer:
"""Get cached indexer instance."""
return ArtifactIndexer(embedding_client=get_embedding_client())
IndexerDependency = Annotated[ArtifactIndexer, Depends(get_indexer)]
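For context, _parse_yaml_to_chunks expects artifacts shaped roughly like the hypothetical example below (field names come from the parser above; the path and values are invented). A file like this would yield four chunks: the factual summary, the interpretive summary, one per method, and one combined invariants chunk.

# artifacts/execution/trader.yaml (hypothetical)
source_file: ./trader.py
tags:
  domain: execution
factual_summary: >
  Coordinates order submission and fill tracking against the exchange adapter.
interpretive_summary: >
  Acts as the orchestration layer between strategy signals and exchange APIs.
methods:
  "submit_order(order) -> Fill":
    description: Validates, routes, and submits an order, returning the resulting fill.
invariants:
  - Orders are never submitted without passing risk checks.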

app/embeddings/retriever.py

@@ -0,0 +1,221 @@
"""FAISS-based retrieval with adaptive selection."""
import json
import logging
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Annotated
import faiss
import numpy as np
from fastapi import Depends
from app.config import settings
from app.embeddings.client import EmbeddingClient, get_embedding_client
logger = logging.getLogger(__name__)
@dataclass
class RetrievedChunk:
"""A chunk retrieved from FAISS search."""
chunk_id: str
content: str
chunk_type: str
artifact_file: str
source_file: str
tags: dict
score: float
class Retriever:
"""FAISS-based retriever with adaptive selection logic."""
def __init__(self, embedding_client: EmbeddingClient):
"""Initialize the retriever.
Args:
embedding_client: Client for generating query embeddings
"""
self.embedding_client = embedding_client
self.embeddings_path = Path(settings.embeddings_path)
self.top_k = settings.rag_top_k
self.threshold = settings.rag_similarity_threshold
self._index: faiss.IndexFlatIP | None = None
self._metadata: list[dict] | None = None
self._loaded = False
def load_index(self) -> bool:
"""Load FAISS index and metadata from disk.
Returns:
True if successfully loaded, False otherwise
"""
index_path = self.embeddings_path / "faiss_index.bin"
metadata_path = self.embeddings_path / "metadata.json"
if not index_path.exists() or not metadata_path.exists():
logger.warning("FAISS index or metadata not found. Run /index first.")
self._loaded = False
return False
try:
self._index = faiss.read_index(str(index_path))
with open(metadata_path, "r", encoding="utf-8") as f:
data = json.load(f)
self._metadata = data.get("chunks", [])
self._loaded = True
logger.info(f"Loaded FAISS index with {self._index.ntotal} vectors")
return True
except Exception as e:
logger.error(f"Failed to load FAISS index: {e}")
self._loaded = False
return False
@property
def is_loaded(self) -> bool:
"""Check if index is loaded."""
return self._loaded and self._index is not None
@property
def index_size(self) -> int:
"""Get number of vectors in index."""
if self._index is None:
return 0
return self._index.ntotal
def _adaptive_select(
self,
indices: np.ndarray,
scores: np.ndarray,
) -> list[tuple[int, float]]:
"""Apply adaptive selection logic.
- Always include top 2 chunks (regardless of score)
- For chunks 3-5: apply threshold
- Limit to self.top_k chunks total
Args:
indices: FAISS result indices
scores: FAISS result scores
Returns:
List of (index, score) tuples for selected chunks
"""
selected = []
for i, (idx, score) in enumerate(zip(indices, scores)):
if idx == -1: # FAISS returns -1 for no match
continue
# Always take top 2
if i < 2:
selected.append((int(idx), float(score)))
# Apply threshold for remaining
elif score >= self.threshold and len(selected) < self.top_k:
selected.append((int(idx), float(score)))
return selected
def _apply_diversity_filter(
self,
candidates: list[tuple[int, float]],
max_per_artifact: int = 2,
) -> list[tuple[int, float]]:
"""Limit chunks per artifact for diversity.
Args:
candidates: List of (index, score) tuples
max_per_artifact: Maximum chunks from same artifact
Returns:
Filtered list of (index, score) tuples
"""
artifact_counts: dict[str, int] = {}
filtered = []
for idx, score in candidates:
chunk = self._metadata[idx]
artifact = chunk["artifact_file"]
if artifact_counts.get(artifact, 0) < max_per_artifact:
filtered.append((idx, score))
artifact_counts[artifact] = artifact_counts.get(artifact, 0) + 1
return filtered
async def search(self, query: str) -> list[RetrievedChunk]:
"""Search for relevant chunks.
Args:
query: User's question
Returns:
List of retrieved chunks with relevance scores
"""
if not self.is_loaded:
if not self.load_index():
return []
# Generate query embedding
query_embedding = await self.embedding_client.embed(query)
query_vector = np.array([query_embedding], dtype=np.float32)
# Normalize for cosine similarity
faiss.normalize_L2(query_vector)
# Search FAISS (get more candidates than needed for filtering)
k_search = min(8, self._index.ntotal)
scores, indices = self._index.search(query_vector, k_search)
# Apply adaptive selection
selected = self._adaptive_select(indices[0], scores[0])
# Apply diversity filter
filtered = self._apply_diversity_filter(selected)
# Build result chunks
results = []
for idx, score in filtered:
chunk_data = self._metadata[idx]
results.append(RetrievedChunk(
chunk_id=chunk_data["chunk_id"],
content=chunk_data["content"],
chunk_type=chunk_data["chunk_type"],
artifact_file=chunk_data["artifact_file"],
source_file=chunk_data["source_file"],
tags=chunk_data.get("tags", {}),
score=score,
))
logger.debug(f"Retrieved {len(results)} chunks for query")
return results
# Singleton retriever instance
_retriever: Retriever | None = None
def get_retriever() -> Retriever:
"""Get singleton retriever instance (lazily initialized)."""
global _retriever
if _retriever is None:
_retriever = Retriever(embedding_client=get_embedding_client())
# Attempt to load index at startup
_retriever.load_index()
return _retriever
def reset_retriever() -> None:
"""Reset the singleton retriever (for reloading after re-indexing)."""
global _retriever
_retriever = None
RetrieverDependency = Annotated[Retriever, Depends(get_retriever)]
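To make the adaptive selection concrete, a usage sketch with invented scores (assumes an index was already built via /index and an OpenAI key is available):

import asyncio

from app.embeddings.retriever import get_retriever

# Hypothetical scores for one query with threshold=0.70, top_k=5:
#   ranks 1-2 (0.82, 0.74) are always kept,
#   rank 3 (0.71) is kept because it meets the threshold,
#   ranks 4-5 (0.65, 0.58) are dropped.
# The diversity filter then caps results at 2 chunks per artifact file.


async def demo() -> None:
    retriever = get_retriever()  # loads embeddings/faiss_index.bin if present
    for chunk in await retriever.search("How are orders routed to the exchange?"):
        print(f"{chunk.score:.2f} {chunk.source_file} ({chunk.chunk_type})")


asyncio.run(demo())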

app/intent/__init__.py

@@ -0,0 +1,9 @@
"""Intent classification module."""
from app.intent.classifier import IntentClassifier, get_intent_classifier, IntentClassifierDependency
__all__ = [
"IntentClassifier",
"get_intent_classifier",
"IntentClassifierDependency",
]

app/intent/classifier.py

@@ -0,0 +1,96 @@
"""Lightweight intent classification using gpt-4o-mini."""
import logging
from functools import lru_cache
from typing import Annotated, Literal
from fastapi import Depends
from openai import AsyncOpenAI
from app.config import settings
from app.memory.conversation import Message
logger = logging.getLogger(__name__)
Intent = Literal["codebase", "general", "clarification"]
INTENT_PROMPT = """Classify this user message into one category:
- "codebase": Questions about trading system code, architecture, files, methods, execution, strategies, exchanges, risk management, order handling, or technical implementation
- "general": Greetings, meta-questions, off-topic ("How are you?", "What can you do?", "Hello")
- "clarification": Follow-ups that rely on conversation context, not new retrieval ("Tell me more", "What did you mean?", "Can you explain that?")
IMPORTANT: If the user is asking about specific code, files, classes, methods, or system behavior, classify as "codebase".
Respond with ONLY the category name, nothing else."""
class IntentClassifier:
"""Lightweight intent classifier using gpt-4o-mini."""
def __init__(self, api_key: str):
"""Initialize the classifier.
Args:
api_key: OpenAI API key
"""
self.client = AsyncOpenAI(api_key=api_key)
self.model = "gpt-4o-mini"
async def classify(
self,
message: str,
history: list[Message] | None = None,
) -> Intent:
"""Classify user message intent.
Args:
message: User's message
history: Optional conversation history for context
Returns:
Classified intent: "codebase", "general", or "clarification"
"""
# Build context from history (last 2 turns)
context = ""
if history and len(history) >= 2:
recent = history[-4:] # Last 2 exchanges
context = "Recent conversation:\n"
for msg in recent:
role = "User" if msg.role == "user" else "Assistant"
context += f"{role}: {msg.content[:100]}...\n" if len(msg.content) > 100 else f"{role}: {msg.content}\n"
context += "\n"
try:
response = await self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": INTENT_PROMPT},
{"role": "user", "content": f"{context}Current message: {message}"},
],
max_tokens=10,
temperature=0,
)
raw_intent = response.choices[0].message.content.strip().lower()
# Validate intent
if raw_intent in ("codebase", "general", "clarification"):
logger.debug(f"Classified intent: {raw_intent}")
return raw_intent
# Default to codebase for ambiguous cases (safer for RAG)
logger.warning(f"Unexpected intent response: {raw_intent}, defaulting to codebase")
return "codebase"
except Exception as e:
logger.warning(f"Intent classification failed: {e}, defaulting to codebase")
return "codebase"
@lru_cache()
def get_intent_classifier() -> IntentClassifier:
"""Get cached intent classifier instance."""
return IntentClassifier(api_key=settings.openai_api_key)
IntentClassifierDependency = Annotated[IntentClassifier, Depends(get_intent_classifier)]
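Illustrative classification calls (inputs are hypothetical; outputs depend on the model, so the expected labels are only indicative):

import asyncio

from app.intent.classifier import get_intent_classifier


async def demo() -> None:
    classifier = get_intent_classifier()
    print(await classifier.classify("How does trader.py size positions?"))  # expected: codebase
    print(await classifier.classify("Hello, what can you do?"))  # expected: general
    print(await classifier.classify("Can you explain that in more detail?"))  # expected: clarification


asyncio.run(demo())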

app/main.py

@@ -1,17 +1,24 @@
 import logging
 import uuid
 
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, BackgroundTasks
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 
 from app.auth import ServiceAuthDependency
 from app.config import settings, MAX_MESSAGE_LENGTH
+from app.embeddings import RetrieverDependency, IndexerDependency, get_retriever, reset_retriever
+from app.embeddings.retriever import RetrievedChunk
+from app.intent import IntentClassifierDependency
 from app.llm import AdapterDependency, LLMError, llm_exception_to_http
+from app.memory import ConversationMemoryDependency
+from app.prompts import build_rag_prompt
 from app.schemas import (
     ChatRequest,
     ChatResponse,
     HealthResponse,
+    IndexResponse,
+    SourceReference,
     StreamChunkEvent,
     StreamDoneEvent,
     StreamErrorEvent,
@@ -43,21 +50,101 @@ app.add_middleware(
 
 @app.get("/health", response_model=HealthResponse)
 async def health_check() -> HealthResponse:
-    """Health check endpoint."""
-    return HealthResponse(status="ok")
+    """Health check endpoint with FAISS status."""
+    retriever = get_retriever()
+    return HealthResponse(
+        status="ok",
+        faiss_loaded=retriever.is_loaded,
+        index_size=retriever.index_size,
+    )
+
+
+@app.get("/debug/search")
+async def debug_search(
+    q: str,
+    retriever: RetrieverDependency,
+) -> dict:
+    """Debug endpoint to test retrieval directly."""
+    chunks = await retriever.search(q)
+    return {
+        "query": q,
+        "chunks_found": len(chunks),
+        "chunks": [
+            {
+                "source_file": c.source_file,
+                "chunk_type": c.chunk_type,
+                "score": round(c.score, 4),
+                "content_preview": c.content[:200] + "..." if len(c.content) > 200 else c.content,
+            }
+            for c in chunks
+        ],
+    }
+
+
+@app.post("/index", response_model=IndexResponse)
+async def reindex_artifacts(
+    background_tasks: BackgroundTasks,
+    indexer: IndexerDependency,
+    _auth: ServiceAuthDependency,
+) -> IndexResponse:
+    """Trigger re-indexing of YAML artifacts.
+
+    Builds FAISS index from all YAML files in the artifacts directory.
+    """
+    logger.info("Starting artifact indexing...")
+    try:
+        result = await indexer.build_index()
+        # Reset retriever to reload new index
+        reset_retriever()
+        logger.info(
+            f"Indexing completed: {result.chunks_indexed} chunks from {result.artifacts_processed} artifacts"
+        )
+        return IndexResponse(
+            status=result.status,
+            chunks_indexed=result.chunks_indexed,
+            artifacts_processed=result.artifacts_processed,
+        )
+    except Exception as e:
+        logger.error(f"Indexing failed: {e}")
+        raise HTTPException(status_code=500, detail=f"Indexing failed: {str(e)}")
+
+
+def _chunks_to_sources(chunks: list[RetrievedChunk]) -> list[SourceReference]:
+    """Convert retrieved chunks to source references."""
+    return [
+        SourceReference(
+            artifact_file=chunk.artifact_file,
+            source_file=chunk.source_file,
+            chunk_type=chunk.chunk_type,
+            relevance_score=chunk.score,
+        )
+        for chunk in chunks
+    ]
+
+
 @app.post("/chat", response_model=ChatResponse)
 async def chat(
     request: ChatRequest,
     adapter: AdapterDependency,
+    retriever: RetrieverDependency,
+    memory: ConversationMemoryDependency,
+    classifier: IntentClassifierDependency,
     _auth: ServiceAuthDependency,
 ) -> ChatResponse:
-    """Process a chat message through the LLM adapter.
+    """Process a chat message through the RAG pipeline.
 
     - Validates message length
     - Generates conversation_id if not provided
-    - Routes to appropriate LLM adapter based on LLM_MODE
+    - Classifies intent (codebase/general/clarification)
+    - Retrieves relevant context from FAISS (for codebase intent)
+    - Builds RAG prompt and generates response
+    - Stores conversation turn in in-memory storage
     """
     # Validate message length
     if len(request.message) > MAX_MESSAGE_LENGTH:
@@ -70,19 +157,42 @@ async def chat(
     # Generate or use provided conversation_id
     conversation_id = request.conversation_id or str(uuid.uuid4())
 
-    # Log request metadata (not content)
+    # Get conversation history
+    history = await memory.get_history(conversation_id)
+
+    # Classify intent
+    intent = await classifier.classify(request.message, history)
+
+    # Log request metadata
     logger.info(
         "Chat request received",
         extra={
             "conversation_id": conversation_id,
             "message_length": len(request.message),
             "mode": settings.llm_mode,
+            "intent": intent,
         },
     )
 
+    # Retrieve context for codebase questions
+    chunks: list[RetrievedChunk] = []
+    if intent == "codebase":
+        chunks = await retriever.search(request.message)
+        logger.debug(f"Retrieved {len(chunks)} chunks for codebase question")
+
+    # Build RAG prompt
+    system_prompt, user_content = build_rag_prompt(
+        user_message=request.message,
+        intent=intent,
+        chunks=chunks,
+        history=history,
+    )
+
     # Generate response with exception handling
     try:
-        response_text = await adapter.generate(conversation_id, request.message)
+        # For RAG, we pass the full constructed prompt
+        full_prompt = f"{system_prompt}\n\n{user_content}"
+        response_text = await adapter.generate(conversation_id, full_prompt)
     except LLMError as e:
         logger.error(
             "LLM generation failed",
@@ -94,6 +204,15 @@
         )
         raise llm_exception_to_http(e)
 
+    # Store conversation turn
+    sources = _chunks_to_sources(chunks)
+    await memory.store_turn(
+        conversation_id=conversation_id,
+        user_message=request.message,
+        assistant_message=response_text,
+        sources=[s.model_dump() for s in sources] if sources else None,
+    )
+
     # Log response metadata
     logger.info(
         "Chat response generated",
@@ -101,6 +220,8 @@
             "conversation_id": conversation_id,
             "response_length": len(response_text),
             "mode": settings.llm_mode,
+            "intent": intent,
+            "sources_count": len(sources),
         },
     )
@@ -108,7 +229,8 @@ async def chat(
         conversation_id=conversation_id,
         response=response_text,
         mode=settings.llm_mode,
-        sources=[],
+        intent=intent,
+        sources=sources,
     )
@@ -116,14 +238,18 @@ async def chat(
 async def chat_stream(
     request: ChatRequest,
     adapter: AdapterDependency,
+    retriever: RetrieverDependency,
+    memory: ConversationMemoryDependency,
+    classifier: IntentClassifierDependency,
     _auth: ServiceAuthDependency,
 ) -> StreamingResponse:
-    """Stream a chat response through the LLM adapter using Server-Sent Events.
+    """Stream a chat response through the RAG pipeline using Server-Sent Events.
 
     - Validates message length
     - Generates conversation_id if not provided
-    - Routes to appropriate LLM adapter based on LLM_MODE
-    - Returns streaming response with SSE format
+    - Classifies intent and retrieves context
+    - Streams response with SSE format
+    - Stores conversation turn after completion
     """
     # Validate message length
     if len(request.message) > MAX_MESSAGE_LENGTH:
@@ -136,26 +262,64 @@ async def chat_stream(
     # Generate or use provided conversation_id
     conversation_id = request.conversation_id or str(uuid.uuid4())
 
-    # Log request metadata (not content)
+    # Get conversation history
+    history = await memory.get_history(conversation_id)
+
+    # Classify intent
+    intent = await classifier.classify(request.message, history)
+
+    # Retrieve context for codebase questions
+    chunks: list[RetrievedChunk] = []
+    if intent == "codebase":
+        chunks = await retriever.search(request.message)
+
+    # Build RAG prompt
+    system_prompt, user_content = build_rag_prompt(
+        user_message=request.message,
+        intent=intent,
+        chunks=chunks,
+        history=history,
+    )
+
+    sources = _chunks_to_sources(chunks)
+
+    # Log request metadata
     logger.info(
         "Chat stream request received",
         extra={
             "conversation_id": conversation_id,
             "message_length": len(request.message),
             "mode": settings.llm_mode,
+            "intent": intent,
         },
     )
 
     async def stream_response():
         """Async generator that yields SSE-formatted events."""
+        full_response = []
         try:
-            async for chunk in adapter.generate_stream(conversation_id, request.message):
+            full_prompt = f"{system_prompt}\n\n{user_content}"
+            async for chunk in adapter.generate_stream(conversation_id, full_prompt):
+                full_response.append(chunk)
                 event = StreamChunkEvent(content=chunk, conversation_id=conversation_id)
                 yield f"data: {event.model_dump_json()}\n\n"
 
-            # Send completion event
+            # Store conversation turn
+            response_text = "".join(full_response)
+            await memory.store_turn(
+                conversation_id=conversation_id,
+                user_message=request.message,
+                assistant_message=response_text,
+                sources=[s.model_dump() for s in sources] if sources else None,
+            )
+
+            # Send completion event with sources
             done_event = StreamDoneEvent(
-                conversation_id=conversation_id, mode=settings.llm_mode
+                conversation_id=conversation_id,
+                mode=settings.llm_mode,
+                intent=intent,
+                sources=sources,
             )
             yield f"data: {done_event.model_dump_json()}\n\n"
@@ -164,6 +328,8 @@ async def chat_stream(
                 extra={
                     "conversation_id": conversation_id,
                     "mode": settings.llm_mode,
+                    "intent": intent,
+                    "sources_count": len(sources),
                 },
             )
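Taken together, the new endpoints can be exercised roughly as follows (illustrative; the base URL, port, and auth_enabled=False setup are assumptions about a local deployment):

import requests

BASE = "http://localhost:8000"  # hypothetical local instance with auth_enabled=False

print(requests.post(f"{BASE}/index").json())  # build the FAISS index from YAML artifacts
print(requests.get(f"{BASE}/debug/search", params={"q": "order routing"}).json())
print(requests.post(f"{BASE}/chat", json={"message": "How are orders routed?"}).json())
print(requests.get(f"{BASE}/health").json())  # now reports faiss_loaded and index_size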

app/memory/__init__.py

@@ -0,0 +1,9 @@
"""Memory module for conversation history management."""
from app.memory.conversation import ConversationMemory, get_conversation_memory, ConversationMemoryDependency
__all__ = [
"ConversationMemory",
"get_conversation_memory",
"ConversationMemoryDependency",
]

app/memory/conversation.py

@@ -0,0 +1,122 @@
"""Conversation history management with in-memory storage."""
import logging
import time
from dataclasses import dataclass, asdict, field
from typing import Annotated
from fastapi import Depends
from app.config import settings
logger = logging.getLogger(__name__)
MAX_HISTORY_MESSAGES = 20
@dataclass
class Message:
"""A single message in conversation history."""
role: str # "user" or "assistant"
content: str
sources: list[dict] | None = None
@dataclass
class ConversationData:
"""Container for conversation messages with timestamp for TTL."""
messages: list[Message] = field(default_factory=list)
last_updated: float = field(default_factory=time.time)
# Module-level storage for conversations
_conversations: dict[str, ConversationData] = {}
class ConversationMemory:
"""Manages conversation history in memory."""
def __init__(self, ttl: int):
"""Initialize conversation memory.
Args:
ttl: Time-to-live in seconds for conversations
"""
self.ttl = ttl
async def get_history(self, conversation_id: str) -> list[Message]:
"""Get conversation history.
Args:
conversation_id: Conversation identifier
Returns:
List of messages in chronological order, or empty list if expired/not found
"""
data = _conversations.get(conversation_id)
if data is None:
return []
# Check if expired
if time.time() - data.last_updated > self.ttl:
del _conversations[conversation_id]
return []
return data.messages
async def store_turn(
self,
conversation_id: str,
user_message: str,
assistant_message: str,
sources: list[dict] | None = None,
) -> None:
"""Store a conversation turn (user message + assistant response).
Args:
conversation_id: Conversation identifier
user_message: User's message
assistant_message: Assistant's response
sources: Optional source references used in response
"""
# Get existing history (checks TTL)
history = await self.get_history(conversation_id)
# Add new messages
history.append(Message(role="user", content=user_message))
history.append(Message(role="assistant", content=assistant_message, sources=sources))
# Trim to max size (keep most recent)
if len(history) > MAX_HISTORY_MESSAGES:
history = history[-MAX_HISTORY_MESSAGES:]
# Store with updated timestamp
_conversations[conversation_id] = ConversationData(
messages=history,
last_updated=time.time(),
)
logger.debug(f"Stored conversation turn for {conversation_id}")
async def clear(self, conversation_id: str) -> bool:
"""Clear conversation history.
Args:
conversation_id: Conversation identifier
Returns:
True if the conversation existed and was cleared, False otherwise
"""
if conversation_id in _conversations:
del _conversations[conversation_id]
return True
return False
def get_conversation_memory() -> ConversationMemory:
"""Get conversation memory instance."""
return ConversationMemory(ttl=settings.conversation_ttl)
ConversationMemoryDependency = Annotated[ConversationMemory, Depends(get_conversation_memory)]
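One design consequence of the module-level dictionary is that history is process-local: it is not shared across replicas and does not survive restarts, and expired entries are only evicted lazily when a conversation is read again. A tiny sketch of the expiry behaviour (illustrative):

import asyncio

from app.memory.conversation import ConversationMemory


async def demo() -> None:
    memory = ConversationMemory(ttl=1)  # 1-second TTL just for demonstration
    await memory.store_turn("c1", "ping", "pong")
    print(len(await memory.get_history("c1")))  # 2
    await asyncio.sleep(1.5)
    print(len(await memory.get_history("c1")))  # 0 -- expired and evicted on read


asyncio.run(demo())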

app/prompts.py

@@ -0,0 +1,136 @@
"""System prompts for RAG-based codebase Q&A."""
from app.embeddings.retriever import RetrievedChunk
from app.memory.conversation import Message
# For codebase questions WITH retrieved context
CODEBASE_SYSTEM_PROMPT = """You answer questions about the Tyndale trading system using ONLY the provided YAML artifacts.
HARD CONSTRAINTS:
- Do NOT assume access to source code
- Do NOT invent implementation details not in the artifacts
- Do NOT speculate about code mechanics beyond what artifacts describe
- If artifacts do not contain enough information, say so explicitly
RESPONSE STYLE:
- Prefer architectural and behavioral explanations over mechanics
- Reference source files by path (e.g., ./trader.py)
- Explain trading concepts for developers without finance background
- Keep responses focused and concise"""
# For general questions (no RAG context)
GENERAL_SYSTEM_PROMPT = """You are an assistant for the Tyndale trading system documentation.
You can answer general questions, but for specific codebase questions, you need artifact context.
If the user asks about specific code without context, ask them to rephrase or be more specific."""
# For clarification/follow-ups (uses conversation history)
CLARIFICATION_SYSTEM_PROMPT = """You are continuing a conversation about the Tyndale trading system.
Use the conversation history to answer follow-up questions.
If you need to look up new information, ask the user to rephrase as a standalone question."""
def select_system_prompt(intent: str, has_context: bool) -> str:
"""Select appropriate system prompt based on intent and context.
Args:
intent: Classified intent (codebase, general, clarification)
has_context: Whether RAG context was retrieved
Returns:
System prompt string
"""
if intent == "codebase" and has_context:
return CODEBASE_SYSTEM_PROMPT
elif intent == "codebase" and not has_context:
return CODEBASE_SYSTEM_PROMPT + "\n\nNOTE: No relevant artifacts were found for this question. Acknowledge this limitation in your response."
elif intent == "clarification":
return CLARIFICATION_SYSTEM_PROMPT
else:
return GENERAL_SYSTEM_PROMPT
def format_context(chunks: list[RetrievedChunk]) -> str:
"""Format retrieved chunks as context for the LLM.
Args:
chunks: List of retrieved chunks
Returns:
Formatted context string
"""
if not chunks:
return ""
context_parts = ["## Retrieved Artifact Context\n"]
for i, chunk in enumerate(chunks, 1):
context_parts.append(f"### Source {i}: {chunk.source_file} ({chunk.chunk_type})")
context_parts.append(chunk.content)
context_parts.append("") # Empty line separator
return "\n".join(context_parts)
def format_history(history: list[Message], max_messages: int = 10) -> str:
"""Format conversation history for the LLM.
Args:
history: List of conversation messages
max_messages: Maximum messages to include
Returns:
Formatted history string
"""
if not history:
return ""
# Take most recent messages
recent = history[-max_messages:] if len(history) > max_messages else history
history_parts = ["## Conversation History\n"]
for msg in recent:
role = "User" if msg.role == "user" else "Assistant"
history_parts.append(f"**{role}**: {msg.content}")
history_parts.append("")
return "\n".join(history_parts)
def build_rag_prompt(
user_message: str,
intent: str,
chunks: list[RetrievedChunk],
history: list[Message],
) -> tuple[str, str]:
"""Build complete RAG prompt with system message and user content.
Args:
user_message: Current user message
intent: Classified intent
chunks: Retrieved context chunks
history: Conversation history
Returns:
Tuple of (system_prompt, user_content)
"""
# Select system prompt
system_prompt = select_system_prompt(intent, bool(chunks))
# Build user content
parts = []
# Add context if available
if chunks:
parts.append(format_context(chunks))
# Add history for clarification intent
if intent == "clarification" and history:
parts.append(format_history(history))
# Add current question
parts.append(f"## Current Question\n{user_message}")
user_content = "\n\n".join(parts)
return system_prompt, user_content
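A sketch of what build_rag_prompt assembles for a codebase question (the chunk below is hypothetical; the section layout follows format_context above):

from app.embeddings.retriever import RetrievedChunk
from app.prompts import build_rag_prompt

chunk = RetrievedChunk(
    chunk_id="execution/trader.yaml::factual_summary",  # hypothetical artifact
    content="[./trader.py] Factual summary: Coordinates order submission and fills.",
    chunk_type="factual_summary",
    artifact_file="execution/trader.yaml",
    source_file="./trader.py",
    tags={},
    score=0.82,
)

system_prompt, user_content = build_rag_prompt(
    user_message="How are orders submitted?",
    intent="codebase",
    chunks=[chunk],
    history=[],
)
# system_prompt is CODEBASE_SYSTEM_PROMPT; user_content contains a
# "## Retrieved Artifact Context" section with the chunk, followed by
# "## Current Question" and the question itself.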

app/schemas.py

@@ -3,6 +3,15 @@ from typing import Literal
 from pydantic import BaseModel, Field
 
 
+class SourceReference(BaseModel):
+    """Reference to a source artifact chunk used in RAG response."""
+
+    artifact_file: str = Field(..., description="Path to the YAML artifact file")
+    source_file: str = Field(..., description="Path to the source code file")
+    chunk_type: str = Field(..., description="Type of chunk (factual_summary, interpretive_summary, method, invariants)")
+    relevance_score: float = Field(..., description="Similarity score from FAISS search")
+
+
 class ChatRequest(BaseModel):
     """Request model for the /chat endpoint."""

@@ -18,13 +27,18 @@ class ChatResponse(BaseModel):
     conversation_id: str = Field(..., description="Conversation ID (generated if not provided)")
     response: str = Field(..., description="The LLM's response")
     mode: Literal["local", "remote", "openai", "asksage"] = Field(..., description="Which adapter was used")
-    sources: list = Field(default_factory=list, description="Source references (empty for now)")
+    intent: Literal["codebase", "general", "clarification"] = Field(
+        default="general", description="Classified intent of the user message"
+    )
+    sources: list[SourceReference] = Field(default_factory=list, description="Source artifact references used in response")
 
 
 class HealthResponse(BaseModel):
     """Response model for the /health endpoint."""
 
     status: str = Field(default="ok")
+    faiss_loaded: bool = Field(default=False, description="Whether FAISS index is loaded")
+    index_size: int = Field(default=0, description="Number of chunks in FAISS index")
 
 
 class ErrorResponse(BaseModel):

@@ -50,6 +64,18 @@ class StreamDoneEvent(BaseModel):
     type: Literal["done"] = "done"
     conversation_id: str = Field(..., description="Conversation ID")
     mode: Literal["local", "remote", "openai", "asksage"] = Field(..., description="Which adapter was used")
+    intent: Literal["codebase", "general", "clarification"] = Field(
+        default="general", description="Classified intent of the user message"
+    )
+    sources: list[SourceReference] = Field(default_factory=list, description="Source artifact references used in response")
+
+
+class IndexResponse(BaseModel):
+    """Response model for the /index endpoint."""
+
+    status: str = Field(..., description="Indexing status")
+    chunks_indexed: int = Field(default=0, description="Number of chunks indexed")
+    artifacts_processed: int = Field(default=0, description="Number of YAML files processed")
 
 
 class StreamErrorEvent(BaseModel):

requirements.txt

@@ -32,3 +32,4 @@ urllib3==2.6.3
 uvicorn==0.40.0
 watchfiles==1.1.1
 websockets==16.0
+faiss-cpu>=1.12.0