Merge pull request #4 from dannyjosephgarcia/WOOL-18

feat: add CORS middleware and SSE streaming endpoint
Danny Garcia 2026-01-16 12:44:26 -06:00 committed by GitHub
commit e3c4680108
5 changed files with 192 additions and 1 deletion

@@ -1,6 +1,9 @@
# LLM Mode: "local", "remote", "openai", or "asksage"
LLM_MODE=local
# CORS Configuration (comma-separated origins for frontend access)
CORS_ORIGINS=http://localhost:3000
# Remote LLM Configuration (required if LLM_MODE=remote)
LLM_REMOTE_URL=https://your-llm-service.com/generate
LLM_REMOTE_TOKEN=
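
Because CORS_ORIGINS is parsed as a comma-separated list (see cors_origins_list in the settings change below), a single variable can allow several frontends. A minimal sketch, where the second origin is a hypothetical production host and not part of this change:

# Hypothetical example: local dev server plus a production frontend
CORS_ORIGINS=http://localhost:3000,https://app.example.com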

@@ -19,6 +19,14 @@ class Settings(BaseSettings):
asksage_api_key: str = ""
asksage_model: str = "gpt-4o"
# CORS configuration
cors_origins: str = "http://localhost:3000"
@property
def cors_origins_list(self) -> list[str]:
"""Parse comma-separated CORS origins into a list."""
return [origin.strip() for origin in self.cors_origins.split(",") if origin.strip()]
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
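
A quick sketch of how cors_origins_list behaves, assuming Settings accepts keyword overrides as pydantic BaseSettings normally does and that its remaining fields all have defaults like the visible ones (the second origin here is a hypothetical example):

# Sketch: comma-separated origins are split and whitespace-trimmed
from app.config import Settings

settings = Settings(cors_origins="http://localhost:3000, https://app.example.com")
print(settings.cors_origins_list)
# expected: ['http://localhost:3000', 'https://app.example.com']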

@@ -3,6 +3,7 @@
import asyncio
from abc import ABC, abstractmethod
from functools import lru_cache
from collections.abc import AsyncIterator
from typing import Annotated
import httpx
@@ -46,6 +47,27 @@ class LLMAdapter(ABC):
"""
pass
async def generate_stream(
self, conversation_id: str, message: str
) -> AsyncIterator[str]:
"""Stream a response for the given message.
Default implementation yields the full response as a single chunk.
Subclasses can override this to provide true streaming.
Args:
conversation_id: The conversation identifier
message: The user's message
Yields:
Response content chunks
Raises:
LLMError: If generation fails for any reason
"""
response = await self.generate(conversation_id, message)
yield response
class LocalAdapter(LLMAdapter):
"""Local stub adapter for development and testing."""
@@ -183,6 +205,46 @@ class OpenAIAdapter(LLMAdapter):
f"OpenAI API error: {e.message}", status_code=e.status_code or 500
)
async def generate_stream(
self, conversation_id: str, message: str
) -> AsyncIterator[str]:
"""Stream a response using the OpenAI API.
Args:
conversation_id: The conversation identifier (for future use with context)
message: The user's message
Yields:
Response content chunks
Raises:
LLMAuthenticationError: If API key is invalid
LLMRateLimitError: If rate limit is exceeded
LLMConnectionError: If connection fails
LLMError: For other API errors
"""
try:
stream = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": message}],
stream=True,
)
async for chunk in stream:
if chunk.choices and chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
except AuthenticationError as e:
raise LLMAuthenticationError(f"OpenAI authentication failed: {e.message}")
except RateLimitError as e:
raise LLMRateLimitError(f"OpenAI rate limit exceeded: {e.message}")
except APIConnectionError as e:
raise LLMConnectionError(f"Could not connect to OpenAI: {str(e)}")
except APIError as e:
raise LLMError(
f"OpenAI API error: {e.message}", status_code=e.status_code or 500
)
class AskSageAdapter(LLMAdapter):
"""AskSage API adapter using the official asksageclient SDK."""

@@ -2,10 +2,19 @@ import logging
import uuid
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from app.config import settings, MAX_MESSAGE_LENGTH
from app.llm import AdapterDependency, LLMError, llm_exception_to_http
from app.schemas import ChatRequest, ChatResponse, HealthResponse
from app.schemas import (
ChatRequest,
ChatResponse,
HealthResponse,
StreamChunkEvent,
StreamDoneEvent,
StreamErrorEvent,
)
# Configure logging
logging.basicConfig(
@@ -21,6 +30,15 @@ app = FastAPI(
version="0.1.0",
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins_list,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/health", response_model=HealthResponse)
async def health_check() -> HealthResponse:
@@ -89,6 +107,79 @@ async def chat(request: ChatRequest, adapter: AdapterDependency) -> ChatResponse
)
@app.post("/chat/stream")
async def chat_stream(request: ChatRequest, adapter: AdapterDependency) -> StreamingResponse:
"""Stream a chat response through the LLM adapter using Server-Sent Events.
- Validates message length
- Generates conversation_id if not provided
- Routes to appropriate LLM adapter based on LLM_MODE
- Returns streaming response with SSE format
"""
# Validate message length
if len(request.message) > MAX_MESSAGE_LENGTH:
raise HTTPException(
status_code=400,
detail=f"Message exceeds maximum length of {MAX_MESSAGE_LENGTH:,} characters. "
f"Your message has {len(request.message):,} characters.",
)
# Generate or use provided conversation_id
conversation_id = request.conversation_id or str(uuid.uuid4())
# Log request metadata (not content)
logger.info(
"Chat stream request received",
extra={
"conversation_id": conversation_id,
"message_length": len(request.message),
"mode": settings.llm_mode,
},
)
async def stream_response():
"""Async generator that yields SSE-formatted events."""
try:
async for chunk in adapter.generate_stream(conversation_id, request.message):
event = StreamChunkEvent(content=chunk, conversation_id=conversation_id)
yield f"data: {event.model_dump_json()}\n\n"
# Send completion event
done_event = StreamDoneEvent(
conversation_id=conversation_id, mode=settings.llm_mode
)
yield f"data: {done_event.model_dump_json()}\n\n"
logger.info(
"Chat stream completed",
extra={
"conversation_id": conversation_id,
"mode": settings.llm_mode,
},
)
except LLMError as e:
logger.error(
"LLM streaming failed",
extra={
"conversation_id": conversation_id,
"error_type": type(e).__name__,
"error_message": e.message,
},
)
error_event = StreamErrorEvent(message=e.message, code=e.status_code)
yield f"data: {error_event.model_dump_json()}\n\n"
return StreamingResponse(
stream_response(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)
if __name__ == "__main__":
import uvicorn
uvicorn.run("app.main:app", host="127.0.0.1", port=8000, reload=True)
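
Each event is emitted as a single "data: <json>" line followed by a blank line, so a client only needs to strip the "data: " prefix and JSON-decode the rest. A sketch of one way to consume the endpoint with httpx, assuming the server is running locally on port 8000:

# Sketch: read /chat/stream as Server-Sent Events and print each chunk
import asyncio
import json

import httpx

async def main() -> None:
    payload = {"message": "Tell me a joke"}  # conversation_id is optional; the server generates one
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://127.0.0.1:8000/chat/stream", json=payload
        ) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skip the blank separator lines between events
                event = json.loads(line[len("data: "):])
                if event["type"] == "chunk":
                    print(event["content"], end="", flush=True)
                elif event["type"] in ("done", "error"):
                    print()
                    break

asyncio.run(main())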

@@ -31,3 +31,30 @@ class ErrorResponse(BaseModel):
"""Standard error response model."""
detail: str = Field(..., description="Error description")
# --- SSE Streaming Event Models ---
class StreamChunkEvent(BaseModel):
"""SSE event for content chunks during streaming."""
type: Literal["chunk"] = "chunk"
content: str = Field(..., description="Content chunk from the LLM")
conversation_id: str = Field(..., description="Conversation ID")
class StreamDoneEvent(BaseModel):
"""SSE event signaling completion of streaming."""
type: Literal["done"] = "done"
conversation_id: str = Field(..., description="Conversation ID")
mode: Literal["local", "remote", "openai", "asksage"] = Field(..., description="Which adapter was used")
class StreamErrorEvent(BaseModel):
"""SSE event for errors during streaming."""
type: Literal["error"] = "error"
message: str = Field(..., description="Error message")
code: int = Field(default=500, description="HTTP status code")
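
For reference, the wire format the frontend receives is just these models serialized with model_dump_json, one per "data:" line. A small sketch of the resulting payloads (assumes Literal is imported in the schemas module, which sits outside this hunk; the IDs and messages are made up):

# Sketch: what each SSE event looks like on the wire
from app.schemas import StreamChunkEvent, StreamDoneEvent, StreamErrorEvent

print(StreamChunkEvent(content="Hello", conversation_id="abc-123").model_dump_json())
# {"type":"chunk","content":"Hello","conversation_id":"abc-123"}

print(StreamDoneEvent(conversation_id="abc-123", mode="local").model_dump_json())
# {"type":"done","conversation_id":"abc-123","mode":"local"}

print(StreamErrorEvent(message="upstream failure", code=502).model_dump_json())
# {"type":"error","message":"upstream failure","code":502}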