# tyndale-ai-service/app/main.py

import logging
import uuid

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

from app.auth import ServiceAuthDependency
from app.config import settings, MAX_MESSAGE_LENGTH
from app.llm import AdapterDependency, LLMError, llm_exception_to_http
from app.schemas import (
    ChatRequest,
    ChatResponse,
    HealthResponse,
    StreamChunkEvent,
    StreamDoneEvent,
    StreamErrorEvent,
)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Create FastAPI app
app = FastAPI(
    title="Tyndale AI Service",
    description="LLM Chat Service for algorithmic trading support",
    version="0.1.0",
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health", response_model=HealthResponse)
async def health_check() -> HealthResponse:
"""Health check endpoint."""
return HealthResponse(status="ok")
@app.post("/chat", response_model=ChatResponse)
async def chat(
request: ChatRequest,
adapter: AdapterDependency,
_auth: ServiceAuthDependency,
) -> ChatResponse:
"""Process a chat message through the LLM adapter.
- Validates message length
- Generates conversation_id if not provided
- Routes to appropriate LLM adapter based on LLM_MODE
"""
# Validate message length
if len(request.message) > MAX_MESSAGE_LENGTH:
raise HTTPException(
status_code=400,
detail=f"Message exceeds maximum length of {MAX_MESSAGE_LENGTH:,} characters. "
f"Your message has {len(request.message):,} characters.",
)
# Generate or use provided conversation_id
conversation_id = request.conversation_id or str(uuid.uuid4())
# Log request metadata (not content)
logger.info(
"Chat request received",
extra={
"conversation_id": conversation_id,
"message_length": len(request.message),
"mode": settings.llm_mode,
},
)
# Generate response with exception handling
try:
response_text = await adapter.generate(conversation_id, request.message)
except LLMError as e:
logger.error(
"LLM generation failed",
extra={
"conversation_id": conversation_id,
"error_type": type(e).__name__,
"error_message": e.message,
},
)
raise llm_exception_to_http(e)
# Log response metadata
logger.info(
"Chat response generated",
extra={
"conversation_id": conversation_id,
"response_length": len(response_text),
"mode": settings.llm_mode,
},
)
return ChatResponse(
conversation_id=conversation_id,
response=response_text,
mode=settings.llm_mode,
sources=[],
)
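
# Example request to the /chat endpoint (illustrative sketch, not used by the
# service itself). The Authorization header name and bearer-token scheme are
# assumptions; check app/auth.py for what ServiceAuthDependency actually reads.
#
#   curl -X POST http://127.0.0.1:8000/chat \
#     -H "Content-Type: application/json" \
#     -H "Authorization: Bearer <service-token>" \
#     -d '{"message": "Summarize the latest fill report"}'
#
# A successful response mirrors the ChatResponse schema above:
# conversation_id, response, mode, and sources.
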
@app.post("/chat/stream")
async def chat_stream(
request: ChatRequest,
adapter: AdapterDependency,
_auth: ServiceAuthDependency,
) -> StreamingResponse:
"""Stream a chat response through the LLM adapter using Server-Sent Events.
- Validates message length
- Generates conversation_id if not provided
- Routes to appropriate LLM adapter based on LLM_MODE
- Returns streaming response with SSE format
"""
# Validate message length
if len(request.message) > MAX_MESSAGE_LENGTH:
raise HTTPException(
status_code=400,
detail=f"Message exceeds maximum length of {MAX_MESSAGE_LENGTH:,} characters. "
f"Your message has {len(request.message):,} characters.",
)
# Generate or use provided conversation_id
conversation_id = request.conversation_id or str(uuid.uuid4())
# Log request metadata (not content)
logger.info(
"Chat stream request received",
extra={
"conversation_id": conversation_id,
"message_length": len(request.message),
"mode": settings.llm_mode,
},
)
async def stream_response():
"""Async generator that yields SSE-formatted events."""
try:
async for chunk in adapter.generate_stream(conversation_id, request.message):
event = StreamChunkEvent(content=chunk, conversation_id=conversation_id)
yield f"data: {event.model_dump_json()}\n\n"
# Send completion event
done_event = StreamDoneEvent(
conversation_id=conversation_id, mode=settings.llm_mode
)
yield f"data: {done_event.model_dump_json()}\n\n"
logger.info(
"Chat stream completed",
extra={
"conversation_id": conversation_id,
"mode": settings.llm_mode,
},
)
except LLMError as e:
logger.error(
"LLM streaming failed",
extra={
"conversation_id": conversation_id,
"error_type": type(e).__name__,
"error_message": e.message,
},
)
error_event = StreamErrorEvent(message=e.message, code=e.status_code)
yield f"data: {error_event.model_dump_json()}\n\n"
return StreamingResponse(
stream_response(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)
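
# Sketch of consuming the /chat/stream SSE endpoint with httpx (illustrative
# only; the Authorization header is an assumption, see app/auth.py):
#
#   import asyncio
#   import httpx
#
#   async def read_stream() -> None:
#       async with httpx.AsyncClient(timeout=None) as client:
#           async with client.stream(
#               "POST",
#               "http://127.0.0.1:8000/chat/stream",
#               json={"message": "Explain the latest fill report"},
#               headers={"Authorization": "Bearer <service-token>"},
#           ) as response:
#               async for line in response.aiter_lines():
#                   if line.startswith("data: "):
#                       # Each payload is a StreamChunkEvent, StreamDoneEvent,
#                       # or StreamErrorEvent serialized as JSON.
#                       print(line[len("data: "):])
#
#   asyncio.run(read_stream())
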
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app.main:app", host="127.0.0.1", port=8000, reload=True)