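"""FastAPI entry point for the Tyndale AI Service.

Exposes the RAG chat endpoints (/chat and /chat/stream), FAISS index
management (/index, /debug/search), and a /health check. See the route
handlers below for the request/response flow.
"""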
import logging
import uuid

from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

from app.auth import ServiceAuthDependency
from app.config import settings, MAX_MESSAGE_LENGTH
from app.embeddings import RetrieverDependency, IndexerDependency, get_retriever, reset_retriever
from app.embeddings.retriever import RetrievedChunk
from app.intent import IntentClassifierDependency
from app.llm import AdapterDependency, LLMError, llm_exception_to_http
from app.memory import ConversationMemoryDependency
from app.prompts import build_rag_prompt
from app.schemas import (
    ChatRequest,
    ChatResponse,
    HealthResponse,
    IndexResponse,
    SourceReference,
    StreamChunkEvent,
    StreamDoneEvent,
    StreamErrorEvent,
)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Create FastAPI app
app = FastAPI(
    title="Tyndale AI Service",
    description="LLM Chat Service for algorithmic trading support",
    version="0.1.0",
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health", response_model=HealthResponse)
async def health_check() -> HealthResponse:
    """Health check endpoint with FAISS status."""
    retriever = get_retriever()

    return HealthResponse(
        status="ok",
        faiss_loaded=retriever.is_loaded,
        index_size=retriever.index_size,
    )
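# Example (illustrative, not from a test suite): a plain GET, with no auth
# required by this route as written.
#
#   curl http://127.0.0.1:8000/health
#   -> {"status": "ok", "faiss_loaded": true, "index_size": 1234}
#
# Field names follow HealthResponse; the index_size value is hypothetical.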


@app.get("/debug/search")
async def debug_search(
    q: str,
    retriever: RetrieverDependency,
) -> dict:
    """Debug endpoint to test retrieval directly."""
    chunks = await retriever.search(q)
    return {
        "query": q,
        "chunks_found": len(chunks),
        "chunks": [
            {
                "source_file": c.source_file,
                "chunk_type": c.chunk_type,
                "score": round(c.score, 4),
                "content_preview": c.content[:200] + "..." if len(c.content) > 200 else c.content,
            }
            for c in chunks
        ],
    }
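# Example (illustrative query): exercising retrieval without the LLM.
#
#   curl "http://127.0.0.1:8000/debug/search?q=order+router"
#
# The response echoes the query and reports, per chunk, its source_file,
# chunk_type, rounded score, and a 200-character content preview.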


@app.post("/index", response_model=IndexResponse)
async def reindex_artifacts(
    background_tasks: BackgroundTasks,
    indexer: IndexerDependency,
    _auth: ServiceAuthDependency,
) -> IndexResponse:
    """Trigger re-indexing of YAML artifacts.

    Builds FAISS index from all YAML files in the artifacts directory.
    """
    logger.info("Starting artifact indexing...")

    try:
        result = await indexer.build_index()

        # Reset retriever to reload new index
        reset_retriever()

        logger.info(
            f"Indexing completed: {result.chunks_indexed} chunks from {result.artifacts_processed} artifacts"
        )

        return IndexResponse(
            status=result.status,
            chunks_indexed=result.chunks_indexed,
            artifacts_processed=result.artifacts_processed,
        )

    except Exception as e:
        logger.error(f"Indexing failed: {e}")
        raise HTTPException(status_code=500, detail=f"Indexing failed: {str(e)}")
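# Example (illustrative): triggering a rebuild. The route is protected by
# ServiceAuthDependency; the exact credential (header name / token scheme) is
# defined in app.auth and not shown here, so the header below is a placeholder.
#
#   curl -X POST http://127.0.0.1:8000/index -H "<service-auth-header>"
#   -> {"status": "...", "chunks_indexed": ..., "artifacts_processed": ...}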


def _chunks_to_sources(chunks: list[RetrievedChunk]) -> list[SourceReference]:
    """Convert retrieved chunks to source references."""
    return [
        SourceReference(
            artifact_file=chunk.artifact_file,
            source_file=chunk.source_file,
            chunk_type=chunk.chunk_type,
            relevance_score=chunk.score,
        )
        for chunk in chunks
    ]


@app.post("/chat", response_model=ChatResponse)
async def chat(
    request: ChatRequest,
    adapter: AdapterDependency,
    retriever: RetrieverDependency,
    memory: ConversationMemoryDependency,
    classifier: IntentClassifierDependency,
    _auth: ServiceAuthDependency,
) -> ChatResponse:
    """Process a chat message through the RAG pipeline.

    - Validates message length
    - Generates conversation_id if not provided
    - Classifies intent (codebase/general/clarification)
    - Retrieves relevant context from FAISS (for codebase intent)
    - Builds RAG prompt and generates response
    - Stores conversation turn in Redis
    """
    # Validate message length
    if len(request.message) > MAX_MESSAGE_LENGTH:
        raise HTTPException(
            status_code=400,
            detail=f"Message exceeds maximum length of {MAX_MESSAGE_LENGTH:,} characters. "
            f"Your message has {len(request.message):,} characters.",
        )

    # Generate or use provided conversation_id
    conversation_id = request.conversation_id or str(uuid.uuid4())

    # Get conversation history
    history = await memory.get_history(conversation_id)

    # Classify intent
    intent = await classifier.classify(request.message, history)

    # Log request metadata
    logger.info(
        "Chat request received",
        extra={
            "conversation_id": conversation_id,
            "message_length": len(request.message),
            "mode": settings.llm_mode,
            "intent": intent,
        },
    )

    # Retrieve context for codebase questions
    chunks: list[RetrievedChunk] = []
    if intent == "codebase":
        chunks = await retriever.search(request.message)
        logger.debug(f"Retrieved {len(chunks)} chunks for codebase question")

    # Build RAG prompt
    system_prompt, user_content = build_rag_prompt(
        user_message=request.message,
        intent=intent,
        chunks=chunks,
        history=history,
    )

    # Generate response with exception handling
    try:
        # For RAG, we pass the full constructed prompt
        full_prompt = f"{system_prompt}\n\n{user_content}"
        response_text = await adapter.generate(conversation_id, full_prompt)
    except LLMError as e:
        logger.error(
            "LLM generation failed",
            extra={
                "conversation_id": conversation_id,
                "error_type": type(e).__name__,
                "error_message": e.message,
            },
        )
        raise llm_exception_to_http(e)

    # Store conversation turn
    sources = _chunks_to_sources(chunks)
    await memory.store_turn(
        conversation_id=conversation_id,
        user_message=request.message,
        assistant_message=response_text,
        sources=[s.model_dump() for s in sources] if sources else None,
    )

    # Log response metadata
    logger.info(
        "Chat response generated",
        extra={
            "conversation_id": conversation_id,
            "response_length": len(response_text),
            "mode": settings.llm_mode,
            "intent": intent,
            "sources_count": len(sources),
        },
    )

    return ChatResponse(
        conversation_id=conversation_id,
        response=response_text,
        mode=settings.llm_mode,
        intent=intent,
        sources=sources,
    )
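# Example (illustrative) request/response shapes, assuming ChatRequest carries
# at least `message` and an optional `conversation_id` (as used above); the
# same placeholder service-auth header applies.
#
#   POST /chat
#   {"message": "How does the order router size positions?"}
#
#   -> {
#        "conversation_id": "<generated uuid4>",
#        "response": "...",
#        "mode": "<settings.llm_mode>",
#        "intent": "codebase",
#        "sources": [{"artifact_file": "...", "source_file": "...",
#                     "chunk_type": "...", "relevance_score": 0.87}]
#      }
#
# The sources list mirrors SourceReference; all values shown are hypothetical.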


@app.post("/chat/stream")
async def chat_stream(
    request: ChatRequest,
    adapter: AdapterDependency,
    retriever: RetrieverDependency,
    memory: ConversationMemoryDependency,
    classifier: IntentClassifierDependency,
    _auth: ServiceAuthDependency,
) -> StreamingResponse:
    """Stream a chat response through the RAG pipeline using Server-Sent Events.

    - Validates message length
    - Generates conversation_id if not provided
    - Classifies intent and retrieves context
    - Streams response with SSE format
    - Stores conversation turn after completion
    """
    # Validate message length
    if len(request.message) > MAX_MESSAGE_LENGTH:
        raise HTTPException(
            status_code=400,
            detail=f"Message exceeds maximum length of {MAX_MESSAGE_LENGTH:,} characters. "
            f"Your message has {len(request.message):,} characters.",
        )

    # Generate or use provided conversation_id
    conversation_id = request.conversation_id or str(uuid.uuid4())

    # Get conversation history
    history = await memory.get_history(conversation_id)

    # Classify intent
    intent = await classifier.classify(request.message, history)

    # Retrieve context for codebase questions
    chunks: list[RetrievedChunk] = []
    if intent == "codebase":
        chunks = await retriever.search(request.message)

    # Build RAG prompt
    system_prompt, user_content = build_rag_prompt(
        user_message=request.message,
        intent=intent,
        chunks=chunks,
        history=history,
    )

    sources = _chunks_to_sources(chunks)

    # Log request metadata
    logger.info(
        "Chat stream request received",
        extra={
            "conversation_id": conversation_id,
            "message_length": len(request.message),
            "mode": settings.llm_mode,
            "intent": intent,
        },
    )

    async def stream_response():
        """Async generator that yields SSE-formatted events."""
        full_response = []

        try:
            full_prompt = f"{system_prompt}\n\n{user_content}"
            async for chunk in adapter.generate_stream(conversation_id, full_prompt):
                full_response.append(chunk)
                event = StreamChunkEvent(content=chunk, conversation_id=conversation_id)
                yield f"data: {event.model_dump_json()}\n\n"

            # Store conversation turn
            response_text = "".join(full_response)
            await memory.store_turn(
                conversation_id=conversation_id,
                user_message=request.message,
                assistant_message=response_text,
                sources=[s.model_dump() for s in sources] if sources else None,
            )

            # Send completion event with sources
            done_event = StreamDoneEvent(
                conversation_id=conversation_id,
                mode=settings.llm_mode,
                intent=intent,
                sources=sources,
            )
            yield f"data: {done_event.model_dump_json()}\n\n"

            logger.info(
                "Chat stream completed",
                extra={
                    "conversation_id": conversation_id,
                    "mode": settings.llm_mode,
                    "intent": intent,
                    "sources_count": len(sources),
                },
            )

        except LLMError as e:
            logger.error(
                "LLM streaming failed",
                extra={
                    "conversation_id": conversation_id,
                    "error_type": type(e).__name__,
                    "error_message": e.message,
                },
            )
            error_event = StreamErrorEvent(message=e.message, code=e.status_code)
            yield f"data: {error_event.model_dump_json()}\n\n"

    return StreamingResponse(
        stream_response(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )
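# Wire-format sketch for the stream above: each event is one SSE "data:" line
# followed by a blank line. Chunk events carry `content` and `conversation_id`
# (StreamChunkEvent); the final event carries `conversation_id`, `mode`,
# `intent`, and `sources` (StreamDoneEvent); on LLMError a StreamErrorEvent
# with `message` and `code` is sent instead. Any extra fields (e.g. an
# event-type discriminator) come from app.schemas and are not shown here.
#
#   data: {"content": "The order router ...", "conversation_id": "..."}
#
#   data: {"conversation_id": "...", "mode": "...", "intent": "codebase", "sources": [...]}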


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app.main:app", host="127.0.0.1", port=8000, reload=True)
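# Untested client sketch for consuming /chat/stream with httpx. The auth header
# is a placeholder; the real scheme lives in app.auth (ServiceAuthDependency).
#
#   import asyncio
#   import httpx
#
#   async def consume() -> None:
#       async with httpx.AsyncClient(timeout=None) as client:
#           async with client.stream(
#               "POST",
#               "http://127.0.0.1:8000/chat/stream",
#               json={"message": "Summarise the risk checks"},
#               headers={"<service-auth-header>": "<token>"},
#           ) as response:
#               async for line in response.aiter_lines():
#                   if line.startswith("data: "):
#                       print(line[len("data: "):])
#
#   asyncio.run(consume())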