diff --git a/app/config.py b/app/config.py index 93ba45f..2621623 100644 --- a/app/config.py +++ b/app/config.py @@ -22,6 +22,18 @@ class Settings(BaseSettings): # CORS configuration cors_origins: str = "http://localhost:3000" + # Conversation memory configuration + conversation_ttl: int = 86400 # 24 hours + + # Embedding configuration + embedding_model: str = "text-embedding-3-small" + + # RAG configuration + rag_top_k: int = 5 + rag_similarity_threshold: float = 0.70 + artifacts_path: str = "artifacts" + embeddings_path: str = "embeddings" + # Authentication settings auth_enabled: bool = False # Set to True in production auth_audience: str = "" # Backend Cloud Run URL (e.g., https://backend-xxx.run.app) diff --git a/app/embeddings/__init__.py b/app/embeddings/__init__.py new file mode 100644 index 0000000..8816a57 --- /dev/null +++ b/app/embeddings/__init__.py @@ -0,0 +1,18 @@ +"""Embeddings module for YAML artifact indexing and retrieval.""" + +from app.embeddings.client import EmbeddingClient, get_embedding_client, EmbeddingClientDependency +from app.embeddings.indexer import ArtifactIndexer, get_indexer, IndexerDependency +from app.embeddings.retriever import Retriever, get_retriever, RetrieverDependency, reset_retriever + +__all__ = [ + "EmbeddingClient", + "get_embedding_client", + "EmbeddingClientDependency", + "ArtifactIndexer", + "get_indexer", + "IndexerDependency", + "Retriever", + "get_retriever", + "RetrieverDependency", + "reset_retriever", +] diff --git a/app/embeddings/client.py b/app/embeddings/client.py new file mode 100644 index 0000000..b157753 --- /dev/null +++ b/app/embeddings/client.py @@ -0,0 +1,109 @@ +"""OpenAI embedding client wrapper.""" + +import logging +from functools import lru_cache +from typing import Annotated + +from fastapi import Depends +from openai import AsyncOpenAI, AuthenticationError, RateLimitError, APIConnectionError, APIError + +from app.config import settings + +logger = logging.getLogger(__name__) + + +class EmbeddingError(Exception): + """Base exception for embedding operations.""" + + def __init__(self, message: str, status_code: int = 500): + self.message = message + self.status_code = status_code + super().__init__(message) + + +class EmbeddingClient: + """Async wrapper for OpenAI embeddings API.""" + + def __init__(self, api_key: str, model: str = "text-embedding-3-small"): + """Initialize the embedding client. + + Args: + api_key: OpenAI API key + model: Embedding model identifier + """ + self.client = AsyncOpenAI(api_key=api_key) + self.model = model + self.dimensions = 1536 # text-embedding-3-small dimension + + async def embed(self, text: str) -> list[float]: + """Generate embedding for a single text. + + Args: + text: Text to embed + + Returns: + Embedding vector (1536 dimensions) + + Raises: + EmbeddingError: If embedding generation fails + """ + try: + response = await self.client.embeddings.create( + model=self.model, + input=text, + ) + return response.data[0].embedding + + except AuthenticationError as e: + raise EmbeddingError(f"OpenAI authentication failed: {e.message}", 401) + except RateLimitError as e: + raise EmbeddingError(f"OpenAI rate limit exceeded: {e.message}", 429) + except APIConnectionError as e: + raise EmbeddingError(f"Could not connect to OpenAI: {str(e)}", 503) + except APIError as e: + raise EmbeddingError(f"OpenAI API error: {e.message}", e.status_code or 500) + + async def embed_batch(self, texts: list[str]) -> list[list[float]]: + """Generate embeddings for multiple texts. 
+ + Args: + texts: List of texts to embed + + Returns: + List of embedding vectors + + Raises: + EmbeddingError: If embedding generation fails + """ + if not texts: + return [] + + try: + response = await self.client.embeddings.create( + model=self.model, + input=texts, + ) + # Sort by index to ensure correct ordering + sorted_embeddings = sorted(response.data, key=lambda x: x.index) + return [item.embedding for item in sorted_embeddings] + + except AuthenticationError as e: + raise EmbeddingError(f"OpenAI authentication failed: {e.message}", 401) + except RateLimitError as e: + raise EmbeddingError(f"OpenAI rate limit exceeded: {e.message}", 429) + except APIConnectionError as e: + raise EmbeddingError(f"Could not connect to OpenAI: {str(e)}", 503) + except APIError as e: + raise EmbeddingError(f"OpenAI API error: {e.message}", e.status_code or 500) + + +@lru_cache() +def get_embedding_client() -> EmbeddingClient: + """Get cached embedding client instance.""" + return EmbeddingClient( + api_key=settings.openai_api_key, + model=settings.embedding_model, + ) + + +EmbeddingClientDependency = Annotated[EmbeddingClient, Depends(get_embedding_client)] diff --git a/app/embeddings/indexer.py b/app/embeddings/indexer.py new file mode 100644 index 0000000..47244cf --- /dev/null +++ b/app/embeddings/indexer.py @@ -0,0 +1,227 @@ +"""YAML artifact indexer for building FAISS index.""" + +import json +import logging +from dataclasses import dataclass, asdict +from functools import lru_cache +from pathlib import Path +from typing import Annotated + +import faiss +import numpy as np +import yaml +from fastapi import Depends + +from app.config import settings +from app.embeddings.client import EmbeddingClient, get_embedding_client, EmbeddingError + +logger = logging.getLogger(__name__) + + +@dataclass +class Chunk: + """Represents a text chunk from a YAML artifact.""" + + chunk_id: str + content: str + chunk_type: str # factual_summary, interpretive_summary, method, invariants + artifact_file: str + source_file: str + tags: dict + + +@dataclass +class IndexResult: + """Result of indexing operation.""" + + chunks_indexed: int + artifacts_processed: int + status: str + + +class ArtifactIndexer: + """Parses YAML artifacts and builds FAISS index.""" + + def __init__(self, embedding_client: EmbeddingClient): + """Initialize the indexer. + + Args: + embedding_client: Client for generating embeddings + """ + self.embedding_client = embedding_client + self.artifacts_path = Path(settings.artifacts_path) + self.embeddings_path = Path(settings.embeddings_path) + self.dimensions = 1536 + + def _parse_yaml_to_chunks(self, yaml_path: Path) -> list[Chunk]: + """Parse a YAML artifact file into chunks. 
+ + Args: + yaml_path: Path to the YAML file + + Returns: + List of chunks extracted from the file + """ + chunks = [] + artifact_file = str(yaml_path.relative_to(self.artifacts_path)) + + try: + with open(yaml_path, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + except Exception as e: + logger.warning(f"Failed to parse {yaml_path}: {e}") + return chunks + + if not data: + return chunks + + source_file = data.get("source_file", "unknown") + tags = data.get("tags", {}) + + # Chunk 1: Factual summary + if factual := data.get("factual_summary"): + chunks.append(Chunk( + chunk_id=f"{artifact_file}::factual_summary", + content=f"[{source_file}] Factual summary: {factual.strip()}", + chunk_type="factual_summary", + artifact_file=artifact_file, + source_file=source_file, + tags=tags, + )) + + # Chunk 2: Interpretive summary + if interpretive := data.get("interpretive_summary"): + chunks.append(Chunk( + chunk_id=f"{artifact_file}::interpretive_summary", + content=f"[{source_file}] Interpretive summary: {interpretive.strip()}", + chunk_type="interpretive_summary", + artifact_file=artifact_file, + source_file=source_file, + tags=tags, + )) + + # Chunk per method + if methods := data.get("methods"): + for method_sig, method_data in methods.items(): + description = method_data.get("description", "") if isinstance(method_data, dict) else method_data + if description: + chunks.append(Chunk( + chunk_id=f"{artifact_file}::method::{method_sig}", + content=f"[{source_file}] Method {method_sig}: {description.strip()}", + chunk_type="method", + artifact_file=artifact_file, + source_file=source_file, + tags=tags, + )) + + # Chunk for invariants (combined) + if invariants := data.get("invariants"): + invariants_text = " ".join(f"- {inv}" for inv in invariants) + chunks.append(Chunk( + chunk_id=f"{artifact_file}::invariants", + content=f"[{source_file}] Invariants: {invariants_text}", + chunk_type="invariants", + artifact_file=artifact_file, + source_file=source_file, + tags=tags, + )) + + return chunks + + def _collect_all_chunks(self) -> list[Chunk]: + """Collect chunks from all YAML artifacts. + + Returns: + List of all chunks from all artifacts + """ + all_chunks = [] + + for yaml_path in self.artifacts_path.rglob("*.yaml"): + chunks = self._parse_yaml_to_chunks(yaml_path) + all_chunks.extend(chunks) + logger.debug(f"Parsed {len(chunks)} chunks from {yaml_path}") + + return all_chunks + + async def build_index(self) -> IndexResult: + """Build FAISS index from all YAML artifacts. 
+ + Returns: + IndexResult with statistics + + Raises: + EmbeddingError: If embedding generation fails + """ + # Collect all chunks + chunks = self._collect_all_chunks() + + if not chunks: + logger.warning("No chunks found in artifacts") + return IndexResult( + chunks_indexed=0, + artifacts_processed=0, + status="no_artifacts", + ) + + logger.info(f"Generating embeddings for {len(chunks)} chunks...") + + # Generate embeddings in batches + batch_size = 100 + all_embeddings = [] + + for i in range(0, len(chunks), batch_size): + batch = chunks[i:i + batch_size] + texts = [chunk.content for chunk in batch] + embeddings = await self.embedding_client.embed_batch(texts) + all_embeddings.extend(embeddings) + logger.debug(f"Embedded batch {i // batch_size + 1}") + + # Build FAISS index (IndexFlatIP for inner product / cosine similarity on normalized vectors) + embeddings_array = np.array(all_embeddings, dtype=np.float32) + + # Normalize for cosine similarity + faiss.normalize_L2(embeddings_array) + + index = faiss.IndexFlatIP(self.dimensions) + index.add(embeddings_array) + + # Create embeddings directory if needed + self.embeddings_path.mkdir(parents=True, exist_ok=True) + + # Save FAISS index + faiss.write_index(index, str(self.embeddings_path / "faiss_index.bin")) + + # Save metadata + metadata = { + "chunks": [asdict(chunk) for chunk in chunks], + } + with open(self.embeddings_path / "metadata.json", "w", encoding="utf-8") as f: + json.dump(metadata, f, indent=2) + + # Save index info + artifact_files = set(chunk.artifact_file for chunk in chunks) + index_info = { + "total_chunks": len(chunks), + "total_artifacts": len(artifact_files), + "dimensions": self.dimensions, + "index_type": "IndexFlatIP", + } + with open(self.embeddings_path / "index_info.json", "w", encoding="utf-8") as f: + json.dump(index_info, f, indent=2) + + logger.info(f"Indexed {len(chunks)} chunks from {len(artifact_files)} artifacts") + + return IndexResult( + chunks_indexed=len(chunks), + artifacts_processed=len(artifact_files), + status="completed", + ) + + +@lru_cache() +def get_indexer() -> ArtifactIndexer: + """Get cached indexer instance.""" + return ArtifactIndexer(embedding_client=get_embedding_client()) + + +IndexerDependency = Annotated[ArtifactIndexer, Depends(get_indexer)] diff --git a/app/embeddings/retriever.py b/app/embeddings/retriever.py new file mode 100644 index 0000000..b3ccd00 --- /dev/null +++ b/app/embeddings/retriever.py @@ -0,0 +1,221 @@ +"""FAISS-based retrieval with adaptive selection.""" + +import json +import logging +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path +from typing import Annotated + +import faiss +import numpy as np +from fastapi import Depends + +from app.config import settings +from app.embeddings.client import EmbeddingClient, get_embedding_client + +logger = logging.getLogger(__name__) + + +@dataclass +class RetrievedChunk: + """A chunk retrieved from FAISS search.""" + + chunk_id: str + content: str + chunk_type: str + artifact_file: str + source_file: str + tags: dict + score: float + + +class Retriever: + """FAISS-based retriever with adaptive selection logic.""" + + def __init__(self, embedding_client: EmbeddingClient): + """Initialize the retriever. 
+ + Args: + embedding_client: Client for generating query embeddings + """ + self.embedding_client = embedding_client + self.embeddings_path = Path(settings.embeddings_path) + self.top_k = settings.rag_top_k + self.threshold = settings.rag_similarity_threshold + + self._index: faiss.IndexFlatIP | None = None + self._metadata: list[dict] | None = None + self._loaded = False + + def load_index(self) -> bool: + """Load FAISS index and metadata from disk. + + Returns: + True if successfully loaded, False otherwise + """ + index_path = self.embeddings_path / "faiss_index.bin" + metadata_path = self.embeddings_path / "metadata.json" + + if not index_path.exists() or not metadata_path.exists(): + logger.warning("FAISS index or metadata not found. Run /index first.") + self._loaded = False + return False + + try: + self._index = faiss.read_index(str(index_path)) + + with open(metadata_path, "r", encoding="utf-8") as f: + data = json.load(f) + self._metadata = data.get("chunks", []) + + self._loaded = True + logger.info(f"Loaded FAISS index with {self._index.ntotal} vectors") + return True + + except Exception as e: + logger.error(f"Failed to load FAISS index: {e}") + self._loaded = False + return False + + @property + def is_loaded(self) -> bool: + """Check if index is loaded.""" + return self._loaded and self._index is not None + + @property + def index_size(self) -> int: + """Get number of vectors in index.""" + if self._index is None: + return 0 + return self._index.ntotal + + def _adaptive_select( + self, + indices: np.ndarray, + scores: np.ndarray, + ) -> list[tuple[int, float]]: + """Apply adaptive selection logic. + + - Always include top 2 chunks (regardless of score) + - For chunks 3-5: apply threshold + - Limit to self.top_k chunks total + + Args: + indices: FAISS result indices + scores: FAISS result scores + + Returns: + List of (index, score) tuples for selected chunks + """ + selected = [] + + for i, (idx, score) in enumerate(zip(indices, scores)): + if idx == -1: # FAISS returns -1 for no match + continue + + # Always take top 2 + if i < 2: + selected.append((int(idx), float(score))) + # Apply threshold for remaining + elif score >= self.threshold and len(selected) < self.top_k: + selected.append((int(idx), float(score))) + + return selected + + def _apply_diversity_filter( + self, + candidates: list[tuple[int, float]], + max_per_artifact: int = 2, + ) -> list[tuple[int, float]]: + """Limit chunks per artifact for diversity. + + Args: + candidates: List of (index, score) tuples + max_per_artifact: Maximum chunks from same artifact + + Returns: + Filtered list of (index, score) tuples + """ + artifact_counts: dict[str, int] = {} + filtered = [] + + for idx, score in candidates: + chunk = self._metadata[idx] + artifact = chunk["artifact_file"] + + if artifact_counts.get(artifact, 0) < max_per_artifact: + filtered.append((idx, score)) + artifact_counts[artifact] = artifact_counts.get(artifact, 0) + 1 + + return filtered + + async def search(self, query: str) -> list[RetrievedChunk]: + """Search for relevant chunks. 
+ + Args: + query: User's question + + Returns: + List of retrieved chunks with relevance scores + """ + if not self.is_loaded: + if not self.load_index(): + return [] + + # Generate query embedding + query_embedding = await self.embedding_client.embed(query) + query_vector = np.array([query_embedding], dtype=np.float32) + + # Normalize for cosine similarity + faiss.normalize_L2(query_vector) + + # Search FAISS (get more candidates than needed for filtering) + k_search = min(8, self._index.ntotal) + scores, indices = self._index.search(query_vector, k_search) + + # Apply adaptive selection + selected = self._adaptive_select(indices[0], scores[0]) + + # Apply diversity filter + filtered = self._apply_diversity_filter(selected) + + # Build result chunks + results = [] + for idx, score in filtered: + chunk_data = self._metadata[idx] + results.append(RetrievedChunk( + chunk_id=chunk_data["chunk_id"], + content=chunk_data["content"], + chunk_type=chunk_data["chunk_type"], + artifact_file=chunk_data["artifact_file"], + source_file=chunk_data["source_file"], + tags=chunk_data.get("tags", {}), + score=score, + )) + + logger.debug(f"Retrieved {len(results)} chunks for query") + return results + + +# Singleton retriever instance +_retriever: Retriever | None = None + + +def get_retriever() -> Retriever: + """Get singleton retriever instance (lazily initialized).""" + global _retriever + if _retriever is None: + _retriever = Retriever(embedding_client=get_embedding_client()) + # Attempt to load index at startup + _retriever.load_index() + return _retriever + + +def reset_retriever() -> None: + """Reset the singleton retriever (for reloading after re-indexing).""" + global _retriever + _retriever = None + + +RetrieverDependency = Annotated[Retriever, Depends(get_retriever)] diff --git a/app/intent/__init__.py b/app/intent/__init__.py new file mode 100644 index 0000000..444c63f --- /dev/null +++ b/app/intent/__init__.py @@ -0,0 +1,9 @@ +"""Intent classification module.""" + +from app.intent.classifier import IntentClassifier, get_intent_classifier, IntentClassifierDependency + +__all__ = [ + "IntentClassifier", + "get_intent_classifier", + "IntentClassifierDependency", +] diff --git a/app/intent/classifier.py b/app/intent/classifier.py new file mode 100644 index 0000000..3502b65 --- /dev/null +++ b/app/intent/classifier.py @@ -0,0 +1,96 @@ +"""Lightweight intent classification using gpt-4o-mini.""" + +import logging +from functools import lru_cache +from typing import Annotated, Literal + +from fastapi import Depends +from openai import AsyncOpenAI + +from app.config import settings +from app.memory.conversation import Message + +logger = logging.getLogger(__name__) + +Intent = Literal["codebase", "general", "clarification"] + +INTENT_PROMPT = """Classify this user message into one category: +- "codebase": Questions about trading system code, architecture, files, methods, execution, strategies, exchanges, risk management, order handling, or technical implementation +- "general": Greetings, meta-questions, off-topic ("How are you?", "What can you do?", "Hello") +- "clarification": Follow-ups that rely on conversation context, not new retrieval ("Tell me more", "What did you mean?", "Can you explain that?") + +IMPORTANT: If the user is asking about specific code, files, classes, methods, or system behavior, classify as "codebase". 
+ +Respond with ONLY the category name, nothing else.""" + + +class IntentClassifier: + """Lightweight intent classifier using gpt-4o-mini.""" + + def __init__(self, api_key: str): + """Initialize the classifier. + + Args: + api_key: OpenAI API key + """ + self.client = AsyncOpenAI(api_key=api_key) + self.model = "gpt-4o-mini" + + async def classify( + self, + message: str, + history: list[Message] | None = None, + ) -> Intent: + """Classify user message intent. + + Args: + message: User's message + history: Optional conversation history for context + + Returns: + Classified intent: "codebase", "general", or "clarification" + """ + # Build context from history (last 2 turns) + context = "" + if history and len(history) >= 2: + recent = history[-4:] # Last 2 exchanges + context = "Recent conversation:\n" + for msg in recent: + role = "User" if msg.role == "user" else "Assistant" + context += f"{role}: {msg.content[:100]}...\n" if len(msg.content) > 100 else f"{role}: {msg.content}\n" + context += "\n" + + try: + response = await self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": INTENT_PROMPT}, + {"role": "user", "content": f"{context}Current message: {message}"}, + ], + max_tokens=10, + temperature=0, + ) + + raw_intent = response.choices[0].message.content.strip().lower() + + # Validate intent + if raw_intent in ("codebase", "general", "clarification"): + logger.debug(f"Classified intent: {raw_intent}") + return raw_intent + + # Default to codebase for ambiguous cases (safer for RAG) + logger.warning(f"Unexpected intent response: {raw_intent}, defaulting to codebase") + return "codebase" + + except Exception as e: + logger.warning(f"Intent classification failed: {e}, defaulting to codebase") + return "codebase" + + +@lru_cache() +def get_intent_classifier() -> IntentClassifier: + """Get cached intent classifier instance.""" + return IntentClassifier(api_key=settings.openai_api_key) + + +IntentClassifierDependency = Annotated[IntentClassifier, Depends(get_intent_classifier)] diff --git a/app/main.py b/app/main.py index 175b12a..c88d992 100644 --- a/app/main.py +++ b/app/main.py @@ -1,17 +1,24 @@ import logging import uuid -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, BackgroundTasks from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from app.auth import ServiceAuthDependency from app.config import settings, MAX_MESSAGE_LENGTH +from app.embeddings import RetrieverDependency, IndexerDependency, get_retriever, reset_retriever +from app.embeddings.retriever import RetrievedChunk +from app.intent import IntentClassifierDependency from app.llm import AdapterDependency, LLMError, llm_exception_to_http +from app.memory import ConversationMemoryDependency +from app.prompts import build_rag_prompt from app.schemas import ( ChatRequest, ChatResponse, HealthResponse, + IndexResponse, + SourceReference, StreamChunkEvent, StreamDoneEvent, StreamErrorEvent, @@ -43,21 +50,101 @@ app.add_middleware( @app.get("/health", response_model=HealthResponse) async def health_check() -> HealthResponse: - """Health check endpoint.""" - return HealthResponse(status="ok") + """Health check endpoint with FAISS status.""" + retriever = get_retriever() + + return HealthResponse( + status="ok", + faiss_loaded=retriever.is_loaded, + index_size=retriever.index_size, + ) + + +@app.get("/debug/search") +async def debug_search( + q: str, + retriever: RetrieverDependency, +) -> dict: + 
"""Debug endpoint to test retrieval directly.""" + chunks = await retriever.search(q) + return { + "query": q, + "chunks_found": len(chunks), + "chunks": [ + { + "source_file": c.source_file, + "chunk_type": c.chunk_type, + "score": round(c.score, 4), + "content_preview": c.content[:200] + "..." if len(c.content) > 200 else c.content, + } + for c in chunks + ], + } + + +@app.post("/index", response_model=IndexResponse) +async def reindex_artifacts( + background_tasks: BackgroundTasks, + indexer: IndexerDependency, + _auth: ServiceAuthDependency, +) -> IndexResponse: + """Trigger re-indexing of YAML artifacts. + + Builds FAISS index from all YAML files in the artifacts directory. + """ + logger.info("Starting artifact indexing...") + + try: + result = await indexer.build_index() + + # Reset retriever to reload new index + reset_retriever() + + logger.info( + f"Indexing completed: {result.chunks_indexed} chunks from {result.artifacts_processed} artifacts" + ) + + return IndexResponse( + status=result.status, + chunks_indexed=result.chunks_indexed, + artifacts_processed=result.artifacts_processed, + ) + + except Exception as e: + logger.error(f"Indexing failed: {e}") + raise HTTPException(status_code=500, detail=f"Indexing failed: {str(e)}") + + +def _chunks_to_sources(chunks: list[RetrievedChunk]) -> list[SourceReference]: + """Convert retrieved chunks to source references.""" + return [ + SourceReference( + artifact_file=chunk.artifact_file, + source_file=chunk.source_file, + chunk_type=chunk.chunk_type, + relevance_score=chunk.score, + ) + for chunk in chunks + ] @app.post("/chat", response_model=ChatResponse) async def chat( request: ChatRequest, adapter: AdapterDependency, + retriever: RetrieverDependency, + memory: ConversationMemoryDependency, + classifier: IntentClassifierDependency, _auth: ServiceAuthDependency, ) -> ChatResponse: - """Process a chat message through the LLM adapter. + """Process a chat message through the RAG pipeline. 
- Validates message length - Generates conversation_id if not provided - - Routes to appropriate LLM adapter based on LLM_MODE + - Classifies intent (codebase/general/clarification) + - Retrieves relevant context from FAISS (for codebase intent) + - Builds RAG prompt and generates response + - Stores conversation turn in in-memory conversation history """ # Validate message length if len(request.message) > MAX_MESSAGE_LENGTH: @@ -70,19 +157,42 @@ async def chat( # Generate or use provided conversation_id conversation_id = request.conversation_id or str(uuid.uuid4()) - # Log request metadata (not content) + # Get conversation history + history = await memory.get_history(conversation_id) + + # Classify intent + intent = await classifier.classify(request.message, history) + + # Log request metadata logger.info( "Chat request received", extra={ "conversation_id": conversation_id, "message_length": len(request.message), "mode": settings.llm_mode, + "intent": intent, }, ) + # Retrieve context for codebase questions + chunks: list[RetrievedChunk] = [] + if intent == "codebase": + chunks = await retriever.search(request.message) + logger.debug(f"Retrieved {len(chunks)} chunks for codebase question") + + # Build RAG prompt + system_prompt, user_content = build_rag_prompt( + user_message=request.message, + intent=intent, + chunks=chunks, + history=history, + ) + # Generate response with exception handling try: - response_text = await adapter.generate(conversation_id, request.message) + # For RAG, we pass the full constructed prompt + full_prompt = f"{system_prompt}\n\n{user_content}" + response_text = await adapter.generate(conversation_id, full_prompt) except LLMError as e: logger.error( "LLM generation failed", @@ -94,6 +204,15 @@ async def chat( ) raise llm_exception_to_http(e) + # Store conversation turn + sources = _chunks_to_sources(chunks) + await memory.store_turn( + conversation_id=conversation_id, + user_message=request.message, + assistant_message=response_text, + sources=[s.model_dump() for s in sources] if sources else None, + ) + # Log response metadata logger.info( "Chat response generated", @@ -101,6 +220,8 @@ async def chat( "conversation_id": conversation_id, "response_length": len(response_text), "mode": settings.llm_mode, + "intent": intent, + "sources_count": len(sources), }, ) @@ -108,7 +229,8 @@ async def chat( conversation_id=conversation_id, response=response_text, mode=settings.llm_mode, - sources=[], + intent=intent, + sources=sources, ) @@ -116,14 +238,18 @@ async def chat( async def chat_stream( request: ChatRequest, adapter: AdapterDependency, + retriever: RetrieverDependency, + memory: ConversationMemoryDependency, + classifier: IntentClassifierDependency, _auth: ServiceAuthDependency, ) -> StreamingResponse: - """Stream a chat response through the LLM adapter using Server-Sent Events. + """Stream a chat response through the RAG pipeline using Server-Sent Events.
- Validates message length - Generates conversation_id if not provided - - Routes to appropriate LLM adapter based on LLM_MODE - - Returns streaming response with SSE format + - Classifies intent and retrieves context + - Streams response with SSE format + - Stores conversation turn after completion """ # Validate message length if len(request.message) > MAX_MESSAGE_LENGTH: @@ -136,26 +262,64 @@ async def chat_stream( # Generate or use provided conversation_id conversation_id = request.conversation_id or str(uuid.uuid4()) - # Log request metadata (not content) + # Get conversation history + history = await memory.get_history(conversation_id) + + # Classify intent + intent = await classifier.classify(request.message, history) + + # Retrieve context for codebase questions + chunks: list[RetrievedChunk] = [] + if intent == "codebase": + chunks = await retriever.search(request.message) + + # Build RAG prompt + system_prompt, user_content = build_rag_prompt( + user_message=request.message, + intent=intent, + chunks=chunks, + history=history, + ) + + sources = _chunks_to_sources(chunks) + + # Log request metadata logger.info( "Chat stream request received", extra={ "conversation_id": conversation_id, "message_length": len(request.message), "mode": settings.llm_mode, + "intent": intent, }, ) async def stream_response(): """Async generator that yields SSE-formatted events.""" + full_response = [] + try: - async for chunk in adapter.generate_stream(conversation_id, request.message): + full_prompt = f"{system_prompt}\n\n{user_content}" + async for chunk in adapter.generate_stream(conversation_id, full_prompt): + full_response.append(chunk) event = StreamChunkEvent(content=chunk, conversation_id=conversation_id) yield f"data: {event.model_dump_json()}\n\n" - # Send completion event + # Store conversation turn + response_text = "".join(full_response) + await memory.store_turn( + conversation_id=conversation_id, + user_message=request.message, + assistant_message=response_text, + sources=[s.model_dump() for s in sources] if sources else None, + ) + + # Send completion event with sources done_event = StreamDoneEvent( - conversation_id=conversation_id, mode=settings.llm_mode + conversation_id=conversation_id, + mode=settings.llm_mode, + intent=intent, + sources=sources, ) yield f"data: {done_event.model_dump_json()}\n\n" @@ -164,6 +328,8 @@ async def chat_stream( extra={ "conversation_id": conversation_id, "mode": settings.llm_mode, + "intent": intent, + "sources_count": len(sources), }, ) diff --git a/app/memory/__init__.py b/app/memory/__init__.py new file mode 100644 index 0000000..6033754 --- /dev/null +++ b/app/memory/__init__.py @@ -0,0 +1,9 @@ +"""Memory module for conversation history management.""" + +from app.memory.conversation import ConversationMemory, get_conversation_memory, ConversationMemoryDependency + +__all__ = [ + "ConversationMemory", + "get_conversation_memory", + "ConversationMemoryDependency", +] diff --git a/app/memory/conversation.py b/app/memory/conversation.py new file mode 100644 index 0000000..d2b72e0 --- /dev/null +++ b/app/memory/conversation.py @@ -0,0 +1,122 @@ +"""Conversation history management with in-memory storage.""" + +import logging +import time +from dataclasses import dataclass, asdict, field +from typing import Annotated + +from fastapi import Depends + +from app.config import settings + +logger = logging.getLogger(__name__) + +MAX_HISTORY_MESSAGES = 20 + + +@dataclass +class Message: + """A single message in conversation history.""" + + role: str # "user" or 
"assistant" + content: str + sources: list[dict] | None = None + + +@dataclass +class ConversationData: + """Container for conversation messages with timestamp for TTL.""" + + messages: list[Message] = field(default_factory=list) + last_updated: float = field(default_factory=time.time) + + +# Module-level storage for conversations +_conversations: dict[str, ConversationData] = {} + + +class ConversationMemory: + """Manages conversation history in memory.""" + + def __init__(self, ttl: int): + """Initialize conversation memory. + + Args: + ttl: Time-to-live in seconds for conversations + """ + self.ttl = ttl + + async def get_history(self, conversation_id: str) -> list[Message]: + """Get conversation history. + + Args: + conversation_id: Conversation identifier + + Returns: + List of messages in chronological order, or empty list if expired/not found + """ + data = _conversations.get(conversation_id) + if data is None: + return [] + + # Check if expired + if time.time() - data.last_updated > self.ttl: + del _conversations[conversation_id] + return [] + + return data.messages + + async def store_turn( + self, + conversation_id: str, + user_message: str, + assistant_message: str, + sources: list[dict] | None = None, + ) -> None: + """Store a conversation turn (user message + assistant response). + + Args: + conversation_id: Conversation identifier + user_message: User's message + assistant_message: Assistant's response + sources: Optional source references used in response + """ + # Get existing history (checks TTL) + history = await self.get_history(conversation_id) + + # Add new messages + history.append(Message(role="user", content=user_message)) + history.append(Message(role="assistant", content=assistant_message, sources=sources)) + + # Trim to max size (keep most recent) + if len(history) > MAX_HISTORY_MESSAGES: + history = history[-MAX_HISTORY_MESSAGES:] + + # Store with updated timestamp + _conversations[conversation_id] = ConversationData( + messages=history, + last_updated=time.time(), + ) + + logger.debug(f"Stored conversation turn for {conversation_id}") + + async def clear(self, conversation_id: str) -> bool: + """Clear conversation history. + + Args: + conversation_id: Conversation identifier + + Returns: + True if cleared successfully + """ + if conversation_id in _conversations: + del _conversations[conversation_id] + return True + + +def get_conversation_memory() -> ConversationMemory: + """Get conversation memory instance.""" + return ConversationMemory(ttl=settings.conversation_ttl) + + +ConversationMemoryDependency = Annotated[ConversationMemory, Depends(get_conversation_memory)] diff --git a/app/prompts.py b/app/prompts.py new file mode 100644 index 0000000..87b7d37 --- /dev/null +++ b/app/prompts.py @@ -0,0 +1,136 @@ +"""System prompts for RAG-based codebase Q&A.""" + +from app.embeddings.retriever import RetrievedChunk +from app.memory.conversation import Message + +# For codebase questions WITH retrieved context +CODEBASE_SYSTEM_PROMPT = """You answer questions about the Tyndale trading system using ONLY the provided YAML artifacts. 
+ +HARD CONSTRAINTS: +- Do NOT assume access to source code +- Do NOT invent implementation details not in the artifacts +- Do NOT speculate about code mechanics beyond what artifacts describe +- If artifacts do not contain enough information, say so explicitly + +RESPONSE STYLE: +- Prefer architectural and behavioral explanations over mechanics +- Reference source files by path (e.g., ./trader.py) +- Explain trading concepts for developers without finance background +- Keep responses focused and concise""" + +# For general questions (no RAG context) +GENERAL_SYSTEM_PROMPT = """You are an assistant for the Tyndale trading system documentation. +You can answer general questions, but for specific codebase questions, you need artifact context. +If the user asks about specific code without context, ask them to rephrase or be more specific.""" + +# For clarification/follow-ups (uses conversation history) +CLARIFICATION_SYSTEM_PROMPT = """You are continuing a conversation about the Tyndale trading system. +Use the conversation history to answer follow-up questions. +If you need to look up new information, ask the user to rephrase as a standalone question.""" + + +def select_system_prompt(intent: str, has_context: bool) -> str: + """Select appropriate system prompt based on intent and context. + + Args: + intent: Classified intent (codebase, general, clarification) + has_context: Whether RAG context was retrieved + + Returns: + System prompt string + """ + if intent == "codebase" and has_context: + return CODEBASE_SYSTEM_PROMPT + elif intent == "codebase" and not has_context: + return CODEBASE_SYSTEM_PROMPT + "\n\nNOTE: No relevant artifacts were found for this question. Acknowledge this limitation in your response." + elif intent == "clarification": + return CLARIFICATION_SYSTEM_PROMPT + else: + return GENERAL_SYSTEM_PROMPT + + +def format_context(chunks: list[RetrievedChunk]) -> str: + """Format retrieved chunks as context for the LLM. + + Args: + chunks: List of retrieved chunks + + Returns: + Formatted context string + """ + if not chunks: + return "" + + context_parts = ["## Retrieved Artifact Context\n"] + + for i, chunk in enumerate(chunks, 1): + context_parts.append(f"### Source {i}: {chunk.source_file} ({chunk.chunk_type})") + context_parts.append(chunk.content) + context_parts.append("") # Empty line separator + + return "\n".join(context_parts) + + +def format_history(history: list[Message], max_messages: int = 10) -> str: + """Format conversation history for the LLM. + + Args: + history: List of conversation messages + max_messages: Maximum messages to include + + Returns: + Formatted history string + """ + if not history: + return "" + + # Take most recent messages + recent = history[-max_messages:] if len(history) > max_messages else history + + history_parts = ["## Conversation History\n"] + + for msg in recent: + role = "User" if msg.role == "user" else "Assistant" + history_parts.append(f"**{role}**: {msg.content}") + history_parts.append("") + + return "\n".join(history_parts) + + +def build_rag_prompt( + user_message: str, + intent: str, + chunks: list[RetrievedChunk], + history: list[Message], +) -> tuple[str, str]: + """Build complete RAG prompt with system message and user content. 
+ + Args: + user_message: Current user message + intent: Classified intent + chunks: Retrieved context chunks + history: Conversation history + + Returns: + Tuple of (system_prompt, user_content) + """ + # Select system prompt + system_prompt = select_system_prompt(intent, bool(chunks)) + + # Build user content + parts = [] + + # Add context if available + if chunks: + parts.append(format_context(chunks)) + + # Add history for clarification intent + if intent == "clarification" and history: + parts.append(format_history(history)) + + # Add current question + parts.append(f"## Current Question\n{user_message}") + + user_content = "\n\n".join(parts) + + return system_prompt, user_content diff --git a/app/schemas.py b/app/schemas.py index f645920..46aeff4 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -3,6 +3,15 @@ from typing import Literal from pydantic import BaseModel, Field +class SourceReference(BaseModel): + """Reference to a source artifact chunk used in RAG response.""" + + artifact_file: str = Field(..., description="Path to the YAML artifact file") + source_file: str = Field(..., description="Path to the source code file") + chunk_type: str = Field(..., description="Type of chunk (factual_summary, interpretive_summary, method, invariants)") + relevance_score: float = Field(..., description="Similarity score from FAISS search") + + class ChatRequest(BaseModel): """Request model for the /chat endpoint.""" @@ -18,13 +27,18 @@ class ChatResponse(BaseModel): conversation_id: str = Field(..., description="Conversation ID (generated if not provided)") response: str = Field(..., description="The LLM's response") mode: Literal["local", "remote", "openai", "asksage"] = Field(..., description="Which adapter was used") - sources: list = Field(default_factory=list, description="Source references (empty for now)") + intent: Literal["codebase", "general", "clarification"] = Field( + default="general", description="Classified intent of the user message" + ) + sources: list[SourceReference] = Field(default_factory=list, description="Source artifact references used in response") class HealthResponse(BaseModel): """Response model for the /health endpoint.""" status: str = Field(default="ok") + faiss_loaded: bool = Field(default=False, description="Whether FAISS index is loaded") + index_size: int = Field(default=0, description="Number of chunks in FAISS index") class ErrorResponse(BaseModel): @@ -50,6 +64,18 @@ class StreamDoneEvent(BaseModel): type: Literal["done"] = "done" conversation_id: str = Field(..., description="Conversation ID") mode: Literal["local", "remote", "openai", "asksage"] = Field(..., description="Which adapter was used") + intent: Literal["codebase", "general", "clarification"] = Field( + default="general", description="Classified intent of the user message" + ) + sources: list[SourceReference] = Field(default_factory=list, description="Source artifact references used in response") + + +class IndexResponse(BaseModel): + """Response model for the /index endpoint.""" + + status: str = Field(..., description="Indexing status") + chunks_indexed: int = Field(default=0, description="Number of chunks indexed") + artifacts_processed: int = Field(default=0, description="Number of YAML files processed") class StreamErrorEvent(BaseModel): diff --git a/requirements.txt b/requirements.txt index 5066cb7..e55cd8b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,3 +32,4 @@ urllib3==2.6.3 uvicorn==0.40.0 watchfiles==1.1.1 websockets==16.0 +faiss-cpu>=1.12.0