import logging import uuid from fastapi import FastAPI, HTTPException, BackgroundTasks from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from app.auth import ServiceAuthDependency from app.config import settings, MAX_MESSAGE_LENGTH from app.embeddings import RetrieverDependency, IndexerDependency, get_retriever, reset_retriever from app.embeddings.retriever import RetrievedChunk from app.intent import IntentClassifierDependency from app.llm import AdapterDependency, LLMError, llm_exception_to_http from app.memory import ConversationMemoryDependency from app.prompts import build_rag_prompt from app.schemas import ( ChatRequest, ChatResponse, HealthResponse, IndexResponse, SourceReference, StreamChunkEvent, StreamDoneEvent, StreamErrorEvent, ) # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) # Create FastAPI app app = FastAPI( title="Tyndale AI Service", description="LLM Chat Service for algorithmic trading support", version="0.1.0", ) # Add CORS middleware app.add_middleware( CORSMiddleware, allow_origins=settings.cors_origins_list, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.get("/health", response_model=HealthResponse) async def health_check() -> HealthResponse: """Health check endpoint with FAISS status.""" retriever = get_retriever() return HealthResponse( status="ok", faiss_loaded=retriever.is_loaded, index_size=retriever.index_size, ) @app.get("/debug/search") async def debug_search( q: str, retriever: RetrieverDependency, ) -> dict: """Debug endpoint to test retrieval directly.""" chunks = await retriever.search(q) return { "query": q, "chunks_found": len(chunks), "chunks": [ { "source_file": c.source_file, "chunk_type": c.chunk_type, "score": round(c.score, 4), "content_preview": c.content[:200] + "..." if len(c.content) > 200 else c.content, } for c in chunks ], } @app.post("/index", response_model=IndexResponse) async def reindex_artifacts( background_tasks: BackgroundTasks, indexer: IndexerDependency, _auth: ServiceAuthDependency, ) -> IndexResponse: """Trigger re-indexing of YAML artifacts. Builds FAISS index from all YAML files in the artifacts directory. """ logger.info("Starting artifact indexing...") try: result = await indexer.build_index() # Reset retriever to reload new index reset_retriever() logger.info( f"Indexing completed: {result.chunks_indexed} chunks from {result.artifacts_processed} artifacts" ) return IndexResponse( status=result.status, chunks_indexed=result.chunks_indexed, artifacts_processed=result.artifacts_processed, ) except Exception as e: logger.error(f"Indexing failed: {e}") raise HTTPException(status_code=500, detail=f"Indexing failed: {str(e)}") def _chunks_to_sources(chunks: list[RetrievedChunk]) -> list[SourceReference]: """Convert retrieved chunks to source references.""" return [ SourceReference( artifact_file=chunk.artifact_file, source_file=chunk.source_file, chunk_type=chunk.chunk_type, relevance_score=chunk.score, ) for chunk in chunks ] @app.post("/chat", response_model=ChatResponse) async def chat( request: ChatRequest, adapter: AdapterDependency, retriever: RetrieverDependency, memory: ConversationMemoryDependency, classifier: IntentClassifierDependency, _auth: ServiceAuthDependency, ) -> ChatResponse: """Process a chat message through the RAG pipeline. - Validates message length - Generates conversation_id if not provided - Classifies intent (codebase/general/clarification) - Retrieves relevant context from FAISS (for codebase intent) - Builds RAG prompt and generates response - Stores conversation turn in Redis """ # Validate message length if len(request.message) > MAX_MESSAGE_LENGTH: raise HTTPException( status_code=400, detail=f"Message exceeds maximum length of {MAX_MESSAGE_LENGTH:,} characters. " f"Your message has {len(request.message):,} characters.", ) # Generate or use provided conversation_id conversation_id = request.conversation_id or str(uuid.uuid4()) # Get conversation history history = await memory.get_history(conversation_id) # Classify intent intent = await classifier.classify(request.message, history) # Log request metadata logger.info( "Chat request received", extra={ "conversation_id": conversation_id, "message_length": len(request.message), "mode": settings.llm_mode, "intent": intent, }, ) # Retrieve context for codebase questions chunks: list[RetrievedChunk] = [] if intent == "codebase": chunks = await retriever.search(request.message) logger.debug(f"Retrieved {len(chunks)} chunks for codebase question") # Build RAG prompt system_prompt, user_content = build_rag_prompt( user_message=request.message, intent=intent, chunks=chunks, history=history, ) # Generate response with exception handling try: # For RAG, we pass the full constructed prompt full_prompt = f"{system_prompt}\n\n{user_content}" response_text = await adapter.generate(conversation_id, full_prompt) except LLMError as e: logger.error( "LLM generation failed", extra={ "conversation_id": conversation_id, "error_type": type(e).__name__, "error_message": e.message, }, ) raise llm_exception_to_http(e) # Store conversation turn sources = _chunks_to_sources(chunks) await memory.store_turn( conversation_id=conversation_id, user_message=request.message, assistant_message=response_text, sources=[s.model_dump() for s in sources] if sources else None, ) # Log response metadata logger.info( "Chat response generated", extra={ "conversation_id": conversation_id, "response_length": len(response_text), "mode": settings.llm_mode, "intent": intent, "sources_count": len(sources), }, ) return ChatResponse( conversation_id=conversation_id, response=response_text, mode=settings.llm_mode, intent=intent, sources=sources, ) @app.post("/chat/stream") async def chat_stream( request: ChatRequest, adapter: AdapterDependency, retriever: RetrieverDependency, memory: ConversationMemoryDependency, classifier: IntentClassifierDependency, _auth: ServiceAuthDependency, ) -> StreamingResponse: """Stream a chat response through the RAG pipeline using Server-Sent Events. - Validates message length - Generates conversation_id if not provided - Classifies intent and retrieves context - Streams response with SSE format - Stores conversation turn after completion """ # Validate message length if len(request.message) > MAX_MESSAGE_LENGTH: raise HTTPException( status_code=400, detail=f"Message exceeds maximum length of {MAX_MESSAGE_LENGTH:,} characters. " f"Your message has {len(request.message):,} characters.", ) # Generate or use provided conversation_id conversation_id = request.conversation_id or str(uuid.uuid4()) # Get conversation history history = await memory.get_history(conversation_id) # Classify intent intent = await classifier.classify(request.message, history) # Retrieve context for codebase questions chunks: list[RetrievedChunk] = [] if intent == "codebase": chunks = await retriever.search(request.message) # Build RAG prompt system_prompt, user_content = build_rag_prompt( user_message=request.message, intent=intent, chunks=chunks, history=history, ) sources = _chunks_to_sources(chunks) # Log request metadata logger.info( "Chat stream request received", extra={ "conversation_id": conversation_id, "message_length": len(request.message), "mode": settings.llm_mode, "intent": intent, }, ) async def stream_response(): """Async generator that yields SSE-formatted events.""" full_response = [] try: full_prompt = f"{system_prompt}\n\n{user_content}" async for chunk in adapter.generate_stream(conversation_id, full_prompt): full_response.append(chunk) event = StreamChunkEvent(content=chunk, conversation_id=conversation_id) yield f"data: {event.model_dump_json()}\n\n" # Store conversation turn response_text = "".join(full_response) await memory.store_turn( conversation_id=conversation_id, user_message=request.message, assistant_message=response_text, sources=[s.model_dump() for s in sources] if sources else None, ) # Send completion event with sources done_event = StreamDoneEvent( conversation_id=conversation_id, mode=settings.llm_mode, intent=intent, sources=sources, ) yield f"data: {done_event.model_dump_json()}\n\n" logger.info( "Chat stream completed", extra={ "conversation_id": conversation_id, "mode": settings.llm_mode, "intent": intent, "sources_count": len(sources), }, ) except LLMError as e: logger.error( "LLM streaming failed", extra={ "conversation_id": conversation_id, "error_type": type(e).__name__, "error_message": e.message, }, ) error_event = StreamErrorEvent(message=e.message, code=e.status_code) yield f"data: {error_event.model_dump_json()}\n\n" return StreamingResponse( stream_response(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", }, ) if __name__ == "__main__": import uvicorn uvicorn.run("app.main:app", host="127.0.0.1", port=8000, reload=True)