tyndale-ai-service/app/main.py

import logging
import uuid

from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

from app.auth import ServiceAuthDependency
from app.config import settings, MAX_MESSAGE_LENGTH
from app.embeddings import RetrieverDependency, IndexerDependency, get_retriever, reset_retriever
from app.embeddings.retriever import RetrievedChunk
from app.intent import IntentClassifierDependency
from app.llm import AdapterDependency, LLMError, llm_exception_to_http
from app.memory import ConversationMemoryDependency
from app.prompts import build_rag_prompt
from app.schemas import (
    ChatRequest,
    ChatResponse,
    HealthResponse,
    IndexResponse,
    SourceReference,
    StreamChunkEvent,
    StreamDoneEvent,
    StreamErrorEvent,
)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Create FastAPI app
app = FastAPI(
    title="Tyndale AI Service",
    description="LLM Chat Service for algorithmic trading support",
    version="0.1.0",
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/health", response_model=HealthResponse)
async def health_check() -> HealthResponse:
"""Health check endpoint with FAISS status."""
retriever = get_retriever()
return HealthResponse(
status="ok",
faiss_loaded=retriever.is_loaded,
index_size=retriever.index_size,
)
@app.get("/debug/search")
async def debug_search(
q: str,
retriever: RetrieverDependency,
) -> dict:
"""Debug endpoint to test retrieval directly."""
chunks = await retriever.search(q)
return {
"query": q,
"chunks_found": len(chunks),
"chunks": [
{
"source_file": c.source_file,
"chunk_type": c.chunk_type,
"score": round(c.score, 4),
"content_preview": c.content[:200] + "..." if len(c.content) > 200 else c.content,
}
for c in chunks
],
}
@app.post("/index", response_model=IndexResponse)
async def reindex_artifacts(
background_tasks: BackgroundTasks,
indexer: IndexerDependency,
_auth: ServiceAuthDependency,
) -> IndexResponse:
"""Trigger re-indexing of YAML artifacts.
Builds FAISS index from all YAML files in the artifacts directory.
"""
logger.info("Starting artifact indexing...")
try:
result = await indexer.build_index()
# Reset retriever to reload new index
reset_retriever()
logger.info(
f"Indexing completed: {result.chunks_indexed} chunks from {result.artifacts_processed} artifacts"
)
return IndexResponse(
status=result.status,
chunks_indexed=result.chunks_indexed,
artifacts_processed=result.artifacts_processed,
)
except Exception as e:
logger.error(f"Indexing failed: {e}")
raise HTTPException(status_code=500, detail=f"Indexing failed: {str(e)}")


def _chunks_to_sources(chunks: list[RetrievedChunk]) -> list[SourceReference]:
    """Convert retrieved chunks to source references."""
    return [
        SourceReference(
            artifact_file=chunk.artifact_file,
            source_file=chunk.source_file,
            chunk_type=chunk.chunk_type,
            relevance_score=chunk.score,
        )
        for chunk in chunks
    ]
@app.post("/chat", response_model=ChatResponse)
async def chat(
request: ChatRequest,
adapter: AdapterDependency,
retriever: RetrieverDependency,
memory: ConversationMemoryDependency,
classifier: IntentClassifierDependency,
_auth: ServiceAuthDependency,
) -> ChatResponse:
"""Process a chat message through the RAG pipeline.
- Validates message length
- Generates conversation_id if not provided
- Classifies intent (codebase/general/clarification)
- Retrieves relevant context from FAISS (for codebase intent)
- Builds RAG prompt and generates response
- Stores conversation turn in Redis
"""
# Validate message length
if len(request.message) > MAX_MESSAGE_LENGTH:
raise HTTPException(
status_code=400,
detail=f"Message exceeds maximum length of {MAX_MESSAGE_LENGTH:,} characters. "
f"Your message has {len(request.message):,} characters.",
)
# Generate or use provided conversation_id
conversation_id = request.conversation_id or str(uuid.uuid4())
# Get conversation history
history = await memory.get_history(conversation_id)
# Classify intent
intent = await classifier.classify(request.message, history)
# Log request metadata
logger.info(
"Chat request received",
extra={
"conversation_id": conversation_id,
"message_length": len(request.message),
"mode": settings.llm_mode,
"intent": intent,
},
)
# Retrieve context for codebase questions
chunks: list[RetrievedChunk] = []
if intent == "codebase":
chunks = await retriever.search(request.message)
logger.debug(f"Retrieved {len(chunks)} chunks for codebase question")
# Build RAG prompt
system_prompt, user_content = build_rag_prompt(
user_message=request.message,
intent=intent,
chunks=chunks,
history=history,
)
# Generate response with exception handling
try:
# For RAG, we pass the full constructed prompt
full_prompt = f"{system_prompt}\n\n{user_content}"
response_text = await adapter.generate(conversation_id, full_prompt)
except LLMError as e:
logger.error(
"LLM generation failed",
extra={
"conversation_id": conversation_id,
"error_type": type(e).__name__,
"error_message": e.message,
},
)
raise llm_exception_to_http(e)
# Store conversation turn
sources = _chunks_to_sources(chunks)
await memory.store_turn(
conversation_id=conversation_id,
user_message=request.message,
assistant_message=response_text,
sources=[s.model_dump() for s in sources] if sources else None,
)
# Log response metadata
logger.info(
"Chat response generated",
extra={
"conversation_id": conversation_id,
"response_length": len(response_text),
"mode": settings.llm_mode,
"intent": intent,
"sources_count": len(sources),
},
)
return ChatResponse(
conversation_id=conversation_id,
response=response_text,
mode=settings.llm_mode,
intent=intent,
sources=sources,
)
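

# Illustrative /chat exchange (a sketch, not a contract; the exact field names and
# any additional fields are defined by ChatRequest / ChatResponse in app.schemas):
#
#   request:  {"message": "How are the YAML artifacts indexed?", "conversation_id": null}
#   response: {"conversation_id": "8f3a...", "response": "...", "mode": "<settings.llm_mode>",
#              "intent": "codebase",
#              "sources": [{"artifact_file": "...", "source_file": "...",
#                           "chunk_type": "...", "relevance_score": 0.87}]}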
@app.post("/chat/stream")
async def chat_stream(
request: ChatRequest,
adapter: AdapterDependency,
retriever: RetrieverDependency,
memory: ConversationMemoryDependency,
classifier: IntentClassifierDependency,
_auth: ServiceAuthDependency,
) -> StreamingResponse:
"""Stream a chat response through the RAG pipeline using Server-Sent Events.
- Validates message length
- Generates conversation_id if not provided
- Classifies intent and retrieves context
- Streams response with SSE format
- Stores conversation turn after completion
"""
# Validate message length
if len(request.message) > MAX_MESSAGE_LENGTH:
raise HTTPException(
status_code=400,
detail=f"Message exceeds maximum length of {MAX_MESSAGE_LENGTH:,} characters. "
f"Your message has {len(request.message):,} characters.",
)
# Generate or use provided conversation_id
conversation_id = request.conversation_id or str(uuid.uuid4())
# Get conversation history
history = await memory.get_history(conversation_id)
# Classify intent
intent = await classifier.classify(request.message, history)
# Retrieve context for codebase questions
chunks: list[RetrievedChunk] = []
if intent == "codebase":
chunks = await retriever.search(request.message)
# Build RAG prompt
system_prompt, user_content = build_rag_prompt(
user_message=request.message,
intent=intent,
chunks=chunks,
history=history,
)
sources = _chunks_to_sources(chunks)
# Log request metadata
logger.info(
"Chat stream request received",
extra={
"conversation_id": conversation_id,
"message_length": len(request.message),
"mode": settings.llm_mode,
"intent": intent,
},
)
async def stream_response():
"""Async generator that yields SSE-formatted events."""
full_response = []
try:
full_prompt = f"{system_prompt}\n\n{user_content}"
async for chunk in adapter.generate_stream(conversation_id, full_prompt):
full_response.append(chunk)
event = StreamChunkEvent(content=chunk, conversation_id=conversation_id)
yield f"data: {event.model_dump_json()}\n\n"
# Store conversation turn
response_text = "".join(full_response)
await memory.store_turn(
conversation_id=conversation_id,
user_message=request.message,
assistant_message=response_text,
sources=[s.model_dump() for s in sources] if sources else None,
)
# Send completion event with sources
done_event = StreamDoneEvent(
conversation_id=conversation_id,
mode=settings.llm_mode,
intent=intent,
sources=sources,
)
yield f"data: {done_event.model_dump_json()}\n\n"
logger.info(
"Chat stream completed",
extra={
"conversation_id": conversation_id,
"mode": settings.llm_mode,
"intent": intent,
"sources_count": len(sources),
},
)
except LLMError as e:
logger.error(
"LLM streaming failed",
extra={
"conversation_id": conversation_id,
"error_type": type(e).__name__,
"error_message": e.message,
},
)
error_event = StreamErrorEvent(message=e.message, code=e.status_code)
yield f"data: {error_event.model_dump_json()}\n\n"
return StreamingResponse(
stream_response(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)
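

# Sketch of the SSE frames /chat/stream emits, for reference. The exact JSON keys
# come from StreamChunkEvent / StreamDoneEvent / StreamErrorEvent in app.schemas,
# so the payload shapes below are assumptions, not a contract:
#
#   data: {"content": "first chunk of text ", "conversation_id": "8f3a..."}
#
#   data: {"content": "next chunk of text...", "conversation_id": "8f3a..."}
#
#   data: {"conversation_id": "8f3a...", "mode": "<settings.llm_mode>", "intent": "codebase", "sources": [...]}
#
# Each frame is a single "data:" line followed by a blank line, matching the
# `yield f"data: {...}\n\n"` calls in stream_response above.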


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app.main:app", host="127.0.0.1", port=8000, reload=True)
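

# Local usage sketch (assumptions: the defaults above are in effect, and the
# service-auth header required by app.auth.ServiceAuthDependency is omitted here
# because its name and value are defined elsewhere):
#
#   uvicorn app.main:app --host 127.0.0.1 --port 8000 --reload
#
#   curl -s http://127.0.0.1:8000/health
#
#   curl -s -X POST http://127.0.0.1:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Summarize the indexed artifacts."}'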