feat: add CORS middleware and SSE streaming endpoint

Add CORS support for frontend development, with allowed origins configurable
via the CORS_ORIGINS environment variable. Add a /chat/stream endpoint for
Server-Sent Events (SSE) streaming, with true streaming for the OpenAI
adapter and fallback single-chunk behavior for the other adapters.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
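For context (not part of the commit), a minimal sketch of how a frontend or script might consume the new /chat/stream endpoint; the base URL http://localhost:8000 and the httpx dependency are assumptions:

import json

import httpx


def stream_chat(message: str) -> None:
    """Print a streamed reply chunk by chunk (illustrative client, not part of this commit)."""
    with httpx.stream(
        "POST",
        "http://localhost:8000/chat/stream",  # assumed local dev address
        json={"message": message},
        timeout=None,
    ) as response:
        for line in response.iter_lines():
            if not line.startswith("data: "):
                continue  # skip the blank separator lines between SSE frames
            event = json.loads(line[len("data: "):])
            if event["type"] == "chunk":
                print(event["content"], end="", flush=True)
            elif event["type"] == "done":
                print()
            elif event["type"] == "error":
                print(f"\n[error {event['code']}] {event['message']}")


if __name__ == "__main__":
    stream_chat("Hello!")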
parent f497fde153
commit 6c1cf0655a
@@ -1,6 +1,9 @@
 # LLM Mode: "local", "remote", "openai", or "asksage"
 LLM_MODE=local
 
+# CORS Configuration (comma-separated origins for frontend access)
+CORS_ORIGINS=http://localhost:3000
+
 # Remote LLM Configuration (required if LLM_MODE=remote)
 LLM_REMOTE_URL=https://your-llm-service.com/generate
 LLM_REMOTE_TOKEN=
@@ -19,6 +19,14 @@ class Settings(BaseSettings):
     asksage_api_key: str = ""
     asksage_model: str = "gpt-4o"
 
+    # CORS configuration
+    cors_origins: str = "http://localhost:3000"
+
+    @property
+    def cors_origins_list(self) -> list[str]:
+        """Parse comma-separated CORS origins into a list."""
+        return [origin.strip() for origin in self.cors_origins.split(",") if origin.strip()]
+
     class Config:
         env_file = ".env"
         env_file_encoding = "utf-8"
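For illustration (not part of the diff), cors_origins_list splits the raw value on commas and strips whitespace, so a multi-origin setting parses cleanly; the second origin and the direct keyword override below are assumptions for the example:

from app.config import Settings

# Hypothetical check: override cors_origins directly (pydantic BaseSettings accepts
# keyword overrides, and the other visible fields all have defaults).
settings = Settings(cors_origins="http://localhost:3000, https://app.example.com")
assert settings.cors_origins_list == [
    "http://localhost:3000",
    "https://app.example.com",
]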
@@ -3,6 +3,7 @@
 import asyncio
 from abc import ABC, abstractmethod
 from functools import lru_cache
+from collections.abc import AsyncIterator
 from typing import Annotated
 
 import httpx
@@ -46,6 +47,27 @@ class LLMAdapter(ABC):
         """
         pass
 
+    async def generate_stream(
+        self, conversation_id: str, message: str
+    ) -> AsyncIterator[str]:
+        """Stream a response for the given message.
+
+        Default implementation yields the full response as a single chunk.
+        Subclasses can override this to provide true streaming.
+
+        Args:
+            conversation_id: The conversation identifier
+            message: The user's message
+
+        Yields:
+            Response content chunks
+
+        Raises:
+            LLMError: If generation fails for any reason
+        """
+        response = await self.generate(conversation_id, message)
+        yield response
+
 
 class LocalAdapter(LLMAdapter):
     """Local stub adapter for development and testing."""
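Because of this single-chunk fallback, only adapters that want incremental output need to override generate_stream. A hypothetical override for a toy adapter (EchoAdapter is illustrative, not part of the codebase) could look like:

from collections.abc import AsyncIterator

from app.llm import LLMAdapter  # assumes the adapter base class is importable like this


class EchoAdapter(LLMAdapter):
    """Toy adapter that streams the user's message back word by word."""

    async def generate(self, conversation_id: str, message: str) -> str:
        return message

    async def generate_stream(
        self, conversation_id: str, message: str
    ) -> AsyncIterator[str]:
        # Yield one word at a time instead of a single full-response chunk.
        for word in message.split():
            yield word + " "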
@@ -183,6 +205,46 @@ class OpenAIAdapter(LLMAdapter):
                 f"OpenAI API error: {e.message}", status_code=e.status_code or 500
             )
 
+    async def generate_stream(
+        self, conversation_id: str, message: str
+    ) -> AsyncIterator[str]:
+        """Stream a response using the OpenAI API.
+
+        Args:
+            conversation_id: The conversation identifier (for future use with context)
+            message: The user's message
+
+        Yields:
+            Response content chunks
+
+        Raises:
+            LLMAuthenticationError: If API key is invalid
+            LLMRateLimitError: If rate limit is exceeded
+            LLMConnectionError: If connection fails
+            LLMError: For other API errors
+        """
+        try:
+            stream = await self.client.chat.completions.create(
+                model=self.model,
+                messages=[{"role": "user", "content": message}],
+                stream=True,
+            )
+
+            async for chunk in stream:
+                if chunk.choices and chunk.choices[0].delta.content:
+                    yield chunk.choices[0].delta.content
+
+        except AuthenticationError as e:
+            raise LLMAuthenticationError(f"OpenAI authentication failed: {e.message}")
+        except RateLimitError as e:
+            raise LLMRateLimitError(f"OpenAI rate limit exceeded: {e.message}")
+        except APIConnectionError as e:
+            raise LLMConnectionError(f"Could not connect to OpenAI: {str(e)}")
+        except APIError as e:
+            raise LLMError(
+                f"OpenAI API error: {e.message}", status_code=e.status_code or 500
+            )
+
 
 class AskSageAdapter(LLMAdapter):
     """AskSage API adapter using the official asksageclient SDK."""
app/main.py (93 changed lines)
@@ -2,10 +2,19 @@ import logging
 import uuid
 
 from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
 
 from app.config import settings, MAX_MESSAGE_LENGTH
 from app.llm import AdapterDependency, LLMError, llm_exception_to_http
-from app.schemas import ChatRequest, ChatResponse, HealthResponse
+from app.schemas import (
+    ChatRequest,
+    ChatResponse,
+    HealthResponse,
+    StreamChunkEvent,
+    StreamDoneEvent,
+    StreamErrorEvent,
+)
 
 # Configure logging
 logging.basicConfig(
@@ -21,6 +30,15 @@ app = FastAPI(
     version="0.1.0",
 )
 
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=settings.cors_origins_list,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
 
 @app.get("/health", response_model=HealthResponse)
 async def health_check() -> HealthResponse:
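Since allow_origins is fed by cors_origins_list, additional frontends can be allowed without code changes by extending the environment value; the second origin below is a placeholder:

# .env (illustrative)
CORS_ORIGINS=http://localhost:3000,https://app.example.com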
@@ -89,6 +107,79 @@ async def chat(request: ChatRequest, adapter: AdapterDependency) -> ChatResponse
         )
 
 
+@app.post("/chat/stream")
+async def chat_stream(request: ChatRequest, adapter: AdapterDependency) -> StreamingResponse:
+    """Stream a chat response through the LLM adapter using Server-Sent Events.
+
+    - Validates message length
+    - Generates conversation_id if not provided
+    - Routes to appropriate LLM adapter based on LLM_MODE
+    - Returns streaming response with SSE format
+    """
+    # Validate message length
+    if len(request.message) > MAX_MESSAGE_LENGTH:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Message exceeds maximum length of {MAX_MESSAGE_LENGTH:,} characters. "
+            f"Your message has {len(request.message):,} characters.",
+        )
+
+    # Generate or use provided conversation_id
+    conversation_id = request.conversation_id or str(uuid.uuid4())
+
+    # Log request metadata (not content)
+    logger.info(
+        "Chat stream request received",
+        extra={
+            "conversation_id": conversation_id,
+            "message_length": len(request.message),
+            "mode": settings.llm_mode,
+        },
+    )
+
+    async def stream_response():
+        """Async generator that yields SSE-formatted events."""
+        try:
+            async for chunk in adapter.generate_stream(conversation_id, request.message):
+                event = StreamChunkEvent(content=chunk, conversation_id=conversation_id)
+                yield f"data: {event.model_dump_json()}\n\n"
+
+            # Send completion event
+            done_event = StreamDoneEvent(
+                conversation_id=conversation_id, mode=settings.llm_mode
+            )
+            yield f"data: {done_event.model_dump_json()}\n\n"
+
+            logger.info(
+                "Chat stream completed",
+                extra={
+                    "conversation_id": conversation_id,
+                    "mode": settings.llm_mode,
+                },
+            )
+
+        except LLMError as e:
+            logger.error(
+                "LLM streaming failed",
+                extra={
+                    "conversation_id": conversation_id,
+                    "error_type": type(e).__name__,
+                    "error_message": e.message,
+                },
+            )
+            error_event = StreamErrorEvent(message=e.message, code=e.status_code)
+            yield f"data: {error_event.model_dump_json()}\n\n"
+
+    return StreamingResponse(
+        stream_response(),
+        media_type="text/event-stream",
+        headers={
+            "Cache-Control": "no-cache",
+            "Connection": "keep-alive",
+        },
+    )
+
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run("app.main:app", host="127.0.0.1", port=8000, reload=True)
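On the wire (not shown in the diff), the endpoint emits one data: frame per event, each followed by a blank line; the payloads below are illustrative, with <uuid> standing in for the generated conversation_id:

data: {"type":"chunk","content":"Hello","conversation_id":"<uuid>"}

data: {"type":"chunk","content":" world","conversation_id":"<uuid>"}

data: {"type":"done","conversation_id":"<uuid>","mode":"openai"}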
@@ -31,3 +31,30 @@ class ErrorResponse(BaseModel):
     """Standard error response model."""
 
     detail: str = Field(..., description="Error description")
+
+
+# --- SSE Streaming Event Models ---
+
+
+class StreamChunkEvent(BaseModel):
+    """SSE event for content chunks during streaming."""
+
+    type: Literal["chunk"] = "chunk"
+    content: str = Field(..., description="Content chunk from the LLM")
+    conversation_id: str = Field(..., description="Conversation ID")
+
+
+class StreamDoneEvent(BaseModel):
+    """SSE event signaling completion of streaming."""
+
+    type: Literal["done"] = "done"
+    conversation_id: str = Field(..., description="Conversation ID")
+    mode: Literal["local", "remote", "openai", "asksage"] = Field(..., description="Which adapter was used")
+
+
+class StreamErrorEvent(BaseModel):
+    """SSE event for errors during streaming."""
+
+    type: Literal["error"] = "error"
+    message: str = Field(..., description="Error message")
+    code: int = Field(default=500, description="HTTP status code")
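For completeness, a hedged sketch (not part of the diff) of mapping an incoming data: line back to the matching event model on the consumer side, keyed on the type discriminator; it assumes the consumer can import app.schemas:

import json

from app.schemas import StreamChunkEvent, StreamDoneEvent, StreamErrorEvent

# Hypothetical decoder: pick the model class from the "type" field and validate.
_EVENT_MODELS = {
    "chunk": StreamChunkEvent,
    "done": StreamDoneEvent,
    "error": StreamErrorEvent,
}


def parse_sse_data(line: str):
    """Parse one 'data: {...}' SSE line into the corresponding Pydantic event model."""
    payload = json.loads(line.removeprefix("data: "))
    return _EVENT_MODELS[payload["type"]].model_validate(payload)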