import logging
import uuid
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

from app.auth import ServiceAuthDependency
from app.config import settings, MAX_MESSAGE_LENGTH
from app.embeddings import RetrieverDependency, IndexerDependency, get_retriever, reset_retriever, get_indexer
from app.embeddings.retriever import RetrievedChunk
from app.intent import IntentClassifierDependency
from app.llm import AdapterDependency, LLMError, llm_exception_to_http
from app.memory import ConversationMemoryDependency
from app.prompts import build_rag_prompt
from app.schemas import (
    ChatRequest,
    ChatResponse,
    HealthResponse,
    IndexResponse,
    SourceReference,
    StreamChunkEvent,
    StreamDoneEvent,
    StreamErrorEvent,
)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup and shutdown lifecycle management."""
    # Startup: auto-index if FAISS index is missing
    retriever = get_retriever()
    if not retriever.is_loaded:
        logger.info("FAISS index not found, building index on startup...")
        try:
            indexer = get_indexer()
            result = await indexer.build_index()
            reset_retriever()  # Reset to load newly built index
            logger.info(
                f"Startup indexing completed: {result.chunks_indexed} chunks "
                f"from {result.artifacts_processed} artifacts"
            )
        except Exception as e:
            logger.error(f"Startup indexing failed: {e}")
    else:
        logger.info(f"FAISS index loaded: {retriever.index_size} vectors")

    yield  # App runs here

    # Shutdown: nothing to clean up


# Create FastAPI app
app = FastAPI(
    title="Tyndale AI Service",
    description="LLM Chat Service for algorithmic trading support",
    version="0.1.0",
    lifespan=lifespan,
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health", response_model=HealthResponse)
async def health_check() -> HealthResponse:
    """Health check endpoint with FAISS status."""
    retriever = get_retriever()
    return HealthResponse(
        status="ok",
        faiss_loaded=retriever.is_loaded,
        index_size=retriever.index_size,
    )


@app.get("/debug/search")
async def debug_search(
    q: str,
    retriever: RetrieverDependency,
) -> dict:
    """Debug endpoint to test retrieval directly."""
    chunks = await retriever.search(q)
    return {
        "query": q,
        "chunks_found": len(chunks),
        "chunks": [
            {
                "source_file": c.source_file,
                "chunk_type": c.chunk_type,
                "score": round(c.score, 4),
                "content_preview": c.content[:200] + "..." if len(c.content) > 200 else c.content,
            }
            for c in chunks
        ],
    }


@app.post("/index", response_model=IndexResponse)
async def reindex_artifacts(
    background_tasks: BackgroundTasks,
    indexer: IndexerDependency,
    _auth: ServiceAuthDependency,
) -> IndexResponse:
    """Trigger re-indexing of YAML artifacts.

    Builds FAISS index from all YAML files in the artifacts directory.
""" logger.info("Starting artifact indexing...") try: result = await indexer.build_index() # Reset retriever to reload new index reset_retriever() logger.info( f"Indexing completed: {result.chunks_indexed} chunks from {result.artifacts_processed} artifacts" ) return IndexResponse( status=result.status, chunks_indexed=result.chunks_indexed, artifacts_processed=result.artifacts_processed, ) except Exception as e: logger.error(f"Indexing failed: {e}") raise HTTPException(status_code=500, detail=f"Indexing failed: {str(e)}") def _chunks_to_sources(chunks: list[RetrievedChunk]) -> list[SourceReference]: """Convert retrieved chunks to source references.""" return [ SourceReference( artifact_file=chunk.artifact_file, source_file=chunk.source_file, chunk_type=chunk.chunk_type, relevance_score=chunk.score, ) for chunk in chunks ] @app.post("/chat", response_model=ChatResponse) async def chat( request: ChatRequest, adapter: AdapterDependency, retriever: RetrieverDependency, memory: ConversationMemoryDependency, classifier: IntentClassifierDependency, _auth: ServiceAuthDependency, ) -> ChatResponse: """Process a chat message through the RAG pipeline. - Validates message length - Generates conversation_id if not provided - Classifies intent (codebase/general/clarification) - Retrieves relevant context from FAISS (for codebase intent) - Builds RAG prompt and generates response - Stores conversation turn in Redis """ # Validate message length if len(request.message) > MAX_MESSAGE_LENGTH: raise HTTPException( status_code=400, detail=f"Message exceeds maximum length of {MAX_MESSAGE_LENGTH:,} characters. " f"Your message has {len(request.message):,} characters.", ) # Generate or use provided conversation_id conversation_id = request.conversation_id or str(uuid.uuid4()) # Get conversation history history = await memory.get_history(conversation_id) # Classify intent intent = await classifier.classify(request.message, history) # Log request metadata logger.info( "Chat request received", extra={ "conversation_id": conversation_id, "message_length": len(request.message), "mode": settings.llm_mode, "intent": intent, }, ) # Retrieve context for codebase questions chunks: list[RetrievedChunk] = [] if intent == "codebase": chunks = await retriever.search(request.message) logger.debug(f"Retrieved {len(chunks)} chunks for codebase question") # Build RAG prompt system_prompt, user_content = build_rag_prompt( user_message=request.message, intent=intent, chunks=chunks, history=history, ) # Generate response with exception handling try: # For RAG, we pass the full constructed prompt full_prompt = f"{system_prompt}\n\n{user_content}" response_text = await adapter.generate(conversation_id, full_prompt) except LLMError as e: logger.error( "LLM generation failed", extra={ "conversation_id": conversation_id, "error_type": type(e).__name__, "error_message": e.message, }, ) raise llm_exception_to_http(e) # Store conversation turn sources = _chunks_to_sources(chunks) await memory.store_turn( conversation_id=conversation_id, user_message=request.message, assistant_message=response_text, sources=[s.model_dump() for s in sources] if sources else None, ) # Log response metadata logger.info( "Chat response generated", extra={ "conversation_id": conversation_id, "response_length": len(response_text), "mode": settings.llm_mode, "intent": intent, "sources_count": len(sources), }, ) return ChatResponse( conversation_id=conversation_id, response=response_text, mode=settings.llm_mode, intent=intent, sources=sources, ) 
@app.post("/chat/stream") async def chat_stream( request: ChatRequest, adapter: AdapterDependency, retriever: RetrieverDependency, memory: ConversationMemoryDependency, classifier: IntentClassifierDependency, _auth: ServiceAuthDependency, ) -> StreamingResponse: """Stream a chat response through the RAG pipeline using Server-Sent Events. - Validates message length - Generates conversation_id if not provided - Classifies intent and retrieves context - Streams response with SSE format - Stores conversation turn after completion """ # Validate message length if len(request.message) > MAX_MESSAGE_LENGTH: raise HTTPException( status_code=400, detail=f"Message exceeds maximum length of {MAX_MESSAGE_LENGTH:,} characters. " f"Your message has {len(request.message):,} characters.", ) # Generate or use provided conversation_id conversation_id = request.conversation_id or str(uuid.uuid4()) # Get conversation history history = await memory.get_history(conversation_id) # Classify intent intent = await classifier.classify(request.message, history) # Retrieve context for codebase questions chunks: list[RetrievedChunk] = [] if intent == "codebase": chunks = await retriever.search(request.message) # Build RAG prompt system_prompt, user_content = build_rag_prompt( user_message=request.message, intent=intent, chunks=chunks, history=history, ) sources = _chunks_to_sources(chunks) # Log request metadata logger.info( "Chat stream request received", extra={ "conversation_id": conversation_id, "message_length": len(request.message), "mode": settings.llm_mode, "intent": intent, }, ) async def stream_response(): """Async generator that yields SSE-formatted events.""" full_response = [] try: full_prompt = f"{system_prompt}\n\n{user_content}" async for chunk in adapter.generate_stream(conversation_id, full_prompt): full_response.append(chunk) event = StreamChunkEvent(content=chunk, conversation_id=conversation_id) yield f"data: {event.model_dump_json()}\n\n" # Store conversation turn response_text = "".join(full_response) await memory.store_turn( conversation_id=conversation_id, user_message=request.message, assistant_message=response_text, sources=[s.model_dump() for s in sources] if sources else None, ) # Send completion event with sources done_event = StreamDoneEvent( conversation_id=conversation_id, mode=settings.llm_mode, intent=intent, sources=sources, ) yield f"data: {done_event.model_dump_json()}\n\n" logger.info( "Chat stream completed", extra={ "conversation_id": conversation_id, "mode": settings.llm_mode, "intent": intent, "sources_count": len(sources), }, ) except LLMError as e: logger.error( "LLM streaming failed", extra={ "conversation_id": conversation_id, "error_type": type(e).__name__, "error_message": e.message, }, ) error_event = StreamErrorEvent(message=e.message, code=e.status_code) yield f"data: {error_event.model_dump_json()}\n\n" return StreamingResponse( stream_response(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", }, ) if __name__ == "__main__": import uvicorn uvicorn.run("app.main:app", host="127.0.0.1", port=8000, reload=True)