tyndale-ai-service/app/main.py
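"""Tyndale AI Service: FastAPI app exposing /health, /chat, and /chat/stream.

Chat requests are routed through a pluggable LLM adapter (see app.llm) selected
by LLM_MODE; request/response metadata is logged, but message content is not.
"""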

import logging
import uuid
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from app.config import settings, MAX_MESSAGE_LENGTH
from app.llm import AdapterDependency, LLMError, llm_exception_to_http
from app.schemas import (
    ChatRequest,
    ChatResponse,
    HealthResponse,
    StreamChunkEvent,
    StreamDoneEvent,
    StreamErrorEvent,
)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Create FastAPI app
app = FastAPI(
    title="Tyndale AI Service",
    description="LLM Chat Service for algorithmic trading support",
    version="0.1.0",
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/health", response_model=HealthResponse)
async def health_check() -> HealthResponse:
"""Health check endpoint."""
return HealthResponse(status="ok")
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest, adapter: AdapterDependency) -> ChatResponse:
"""Process a chat message through the LLM adapter.
- Validates message length
- Generates conversation_id if not provided
- Routes to appropriate LLM adapter based on LLM_MODE
"""
# Validate message length
if len(request.message) > MAX_MESSAGE_LENGTH:
raise HTTPException(
status_code=400,
detail=f"Message exceeds maximum length of {MAX_MESSAGE_LENGTH:,} characters. "
f"Your message has {len(request.message):,} characters.",
)
# Generate or use provided conversation_id
conversation_id = request.conversation_id or str(uuid.uuid4())
# Log request metadata (not content)
logger.info(
"Chat request received",
extra={
"conversation_id": conversation_id,
"message_length": len(request.message),
"mode": settings.llm_mode,
},
)
# Generate response with exception handling
try:
response_text = await adapter.generate(conversation_id, request.message)
except LLMError as e:
logger.error(
"LLM generation failed",
extra={
"conversation_id": conversation_id,
"error_type": type(e).__name__,
"error_message": e.message,
},
)
raise llm_exception_to_http(e)
# Log response metadata
logger.info(
"Chat response generated",
extra={
"conversation_id": conversation_id,
"response_length": len(response_text),
"mode": settings.llm_mode,
},
)
return ChatResponse(
conversation_id=conversation_id,
response=response_text,
mode=settings.llm_mode,
sources=[],
)
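
# Example client call for /chat (illustrative sketch only, kept as a comment so it
# is never executed here). It assumes the service is running locally on port 8000
# and that the httpx package is available; the payload fields mirror ChatRequest
# as used above, and the response fields mirror ChatResponse. The message text is
# a hypothetical placeholder.
#
#     import httpx
#
#     resp = httpx.post(
#         "http://127.0.0.1:8000/chat",
#         json={"message": "Summarize today's fills."},
#     )
#     resp.raise_for_status()
#     data = resp.json()
#     print(data["conversation_id"], data["mode"])
#     print(data["response"])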
@app.post("/chat/stream")
async def chat_stream(request: ChatRequest, adapter: AdapterDependency) -> StreamingResponse:
"""Stream a chat response through the LLM adapter using Server-Sent Events.
- Validates message length
- Generates conversation_id if not provided
- Routes to appropriate LLM adapter based on LLM_MODE
- Returns streaming response with SSE format
"""
# Validate message length
if len(request.message) > MAX_MESSAGE_LENGTH:
raise HTTPException(
status_code=400,
detail=f"Message exceeds maximum length of {MAX_MESSAGE_LENGTH:,} characters. "
f"Your message has {len(request.message):,} characters.",
)
# Generate or use provided conversation_id
conversation_id = request.conversation_id or str(uuid.uuid4())
# Log request metadata (not content)
logger.info(
"Chat stream request received",
extra={
"conversation_id": conversation_id,
"message_length": len(request.message),
"mode": settings.llm_mode,
},
)
async def stream_response():
"""Async generator that yields SSE-formatted events."""
try:
async for chunk in adapter.generate_stream(conversation_id, request.message):
event = StreamChunkEvent(content=chunk, conversation_id=conversation_id)
yield f"data: {event.model_dump_json()}\n\n"
# Send completion event
done_event = StreamDoneEvent(
conversation_id=conversation_id, mode=settings.llm_mode
)
yield f"data: {done_event.model_dump_json()}\n\n"
logger.info(
"Chat stream completed",
extra={
"conversation_id": conversation_id,
"mode": settings.llm_mode,
},
)
except LLMError as e:
logger.error(
"LLM streaming failed",
extra={
"conversation_id": conversation_id,
"error_type": type(e).__name__,
"error_message": e.message,
},
)
error_event = StreamErrorEvent(message=e.message, code=e.status_code)
yield f"data: {error_event.model_dump_json()}\n\n"
return StreamingResponse(
stream_response(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)
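
# Example of consuming /chat/stream (illustrative sketch only, kept as a comment).
# It assumes the service is running locally and that httpx is available. Each SSE
# line has the form "data: <json>"; the JSON payload is one of the stream events
# serialized above (StreamChunkEvent, StreamDoneEvent, StreamErrorEvent), so the
# "content" check below is just one way to pick out chunk events. The message text
# is a hypothetical placeholder.
#
#     import json
#     import httpx
#
#     with httpx.stream(
#         "POST",
#         "http://127.0.0.1:8000/chat/stream",
#         json={"message": "Explain the latest risk alert."},
#         timeout=None,
#     ) as resp:
#         for line in resp.iter_lines():
#             if not line.startswith("data: "):
#                 continue
#             payload = json.loads(line[len("data: "):])
#             if "content" in payload:
#                 print(payload["content"], end="", flush=True)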


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app.main:app", host="127.0.0.1", port=8000, reload=True)