Merge pull request #4 from dannyjosephgarcia/WOOL-18
feat: add CORS middleware and SSE streaming endpoint
commit e3c4680108
@@ -1,6 +1,9 @@
 # LLM Mode: "local", "remote", "openai", or "asksage"
 LLM_MODE=local
 
+# CORS Configuration (comma-separated origins for frontend access)
+CORS_ORIGINS=http://localhost:3000
+
 # Remote LLM Configuration (required if LLM_MODE=remote)
 LLM_REMOTE_URL=https://your-llm-service.com/generate
 LLM_REMOTE_TOKEN=
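Since CORS_ORIGINS is parsed as a comma-separated list, allowing a second frontend is just a matter of appending it. A hypothetical example (the deployed URL is a placeholder, not part of this commit):

# Allow the local dev frontend plus a deployed frontend (placeholder URL).
CORS_ORIGINS=http://localhost:3000,https://app.example.com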
@@ -19,6 +19,14 @@ class Settings(BaseSettings):
     asksage_api_key: str = ""
     asksage_model: str = "gpt-4o"
 
+    # CORS configuration
+    cors_origins: str = "http://localhost:3000"
+
+    @property
+    def cors_origins_list(self) -> list[str]:
+        """Parse comma-separated CORS origins into a list."""
+        return [origin.strip() for origin in self.cors_origins.split(",") if origin.strip()]
+
     class Config:
         env_file = ".env"
         env_file_encoding = "utf-8"
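To make the parsing behavior concrete: the property splits on commas, trims whitespace, and drops empty entries. A standalone sketch of that logic (not part of the commit):

# Standalone sketch of the logic behind Settings.cors_origins_list.
def parse_origins(raw: str) -> list[str]:
    # Split on commas, trim whitespace, and skip empty entries.
    return [origin.strip() for origin in raw.split(",") if origin.strip()]


print(parse_origins("http://localhost:3000, https://app.example.com,"))
# ['http://localhost:3000', 'https://app.example.com']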
@@ -3,6 +3,7 @@
 import asyncio
 from abc import ABC, abstractmethod
 from functools import lru_cache
+from collections.abc import AsyncIterator
 from typing import Annotated
 
 import httpx
@@ -46,6 +47,27 @@ class LLMAdapter(ABC):
         """
         pass
 
+    async def generate_stream(
+        self, conversation_id: str, message: str
+    ) -> AsyncIterator[str]:
+        """Stream a response for the given message.
+
+        Default implementation yields the full response as a single chunk.
+        Subclasses can override this to provide true streaming.
+
+        Args:
+            conversation_id: The conversation identifier
+            message: The user's message
+
+        Yields:
+            Response content chunks
+
+        Raises:
+            LLMError: If generation fails for any reason
+        """
+        response = await self.generate(conversation_id, message)
+        yield response
+
 
 class LocalAdapter(LLMAdapter):
     """Local stub adapter for development and testing."""
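The practical effect of this base-class default is that adapters which only implement generate() still work with the streaming endpoint; they simply emit one big chunk. A simplified, self-contained sketch of that behavior (EchoAdapter is a stand-in, not a class from this repo):

# Sketch: the default generate_stream() yields the whole response as one chunk.
import asyncio
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator


class LLMAdapter(ABC):
    @abstractmethod
    async def generate(self, conversation_id: str, message: str) -> str: ...

    async def generate_stream(self, conversation_id: str, message: str) -> AsyncIterator[str]:
        # No true streaming: await the full response and yield it once.
        response = await self.generate(conversation_id, message)
        yield response


class EchoAdapter(LLMAdapter):
    async def generate(self, conversation_id: str, message: str) -> str:
        return f"echo: {message}"


async def main() -> None:
    chunks = [chunk async for chunk in EchoAdapter().generate_stream("conv-1", "hi")]
    print(chunks)  # ['echo: hi'] -- a single chunk


asyncio.run(main())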
@@ -183,6 +205,46 @@ class OpenAIAdapter(LLMAdapter):
                 f"OpenAI API error: {e.message}", status_code=e.status_code or 500
             )
 
+    async def generate_stream(
+        self, conversation_id: str, message: str
+    ) -> AsyncIterator[str]:
+        """Stream a response using the OpenAI API.
+
+        Args:
+            conversation_id: The conversation identifier (for future use with context)
+            message: The user's message
+
+        Yields:
+            Response content chunks
+
+        Raises:
+            LLMAuthenticationError: If API key is invalid
+            LLMRateLimitError: If rate limit is exceeded
+            LLMConnectionError: If connection fails
+            LLMError: For other API errors
+        """
+        try:
+            stream = await self.client.chat.completions.create(
+                model=self.model,
+                messages=[{"role": "user", "content": message}],
+                stream=True,
+            )
+
+            async for chunk in stream:
+                if chunk.choices and chunk.choices[0].delta.content:
+                    yield chunk.choices[0].delta.content
+
+        except AuthenticationError as e:
+            raise LLMAuthenticationError(f"OpenAI authentication failed: {e.message}")
+        except RateLimitError as e:
+            raise LLMRateLimitError(f"OpenAI rate limit exceeded: {e.message}")
+        except APIConnectionError as e:
+            raise LLMConnectionError(f"Could not connect to OpenAI: {str(e)}")
+        except APIError as e:
+            raise LLMError(
+                f"OpenAI API error: {e.message}", status_code=e.status_code or 500
+            )
+
 
 class AskSageAdapter(LLMAdapter):
     """AskSage API adapter using the official asksageclient SDK."""
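One detail worth noting in the guard above: in the OpenAI streaming API the first delta typically carries only the role and the final one can be empty, so chunks without content are skipped. A self-contained sketch of that filtering with fake chunk objects (the Fake* classes are stand-ins, not the real SDK types):

# Sketch: filtering streamed deltas the same way generate_stream() does.
import asyncio
from dataclasses import dataclass


@dataclass
class FakeDelta:
    content: str | None


@dataclass
class FakeChoice:
    delta: FakeDelta


@dataclass
class FakeChunk:
    choices: list[FakeChoice]


async def fake_stream():
    # A role-only first delta (content=None) is skipped by the guard.
    yield FakeChunk(choices=[FakeChoice(delta=FakeDelta(content=None))])
    yield FakeChunk(choices=[FakeChoice(delta=FakeDelta(content="Hel"))])
    yield FakeChunk(choices=[FakeChoice(delta=FakeDelta(content="lo!"))])


async def main() -> None:
    parts = []
    async for chunk in fake_stream():
        if chunk.choices and chunk.choices[0].delta.content:
            parts.append(chunk.choices[0].delta.content)
    print("".join(parts))  # Hello!


asyncio.run(main())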
app/main.py
@@ -2,10 +2,19 @@ import logging
 import uuid
 
 from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
 
 from app.config import settings, MAX_MESSAGE_LENGTH
 from app.llm import AdapterDependency, LLMError, llm_exception_to_http
-from app.schemas import ChatRequest, ChatResponse, HealthResponse
+from app.schemas import (
+    ChatRequest,
+    ChatResponse,
+    HealthResponse,
+    StreamChunkEvent,
+    StreamDoneEvent,
+    StreamErrorEvent,
+)
 
 # Configure logging
 logging.basicConfig(
@@ -21,6 +30,15 @@ app = FastAPI(
     version="0.1.0",
 )
 
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=settings.cors_origins_list,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
 
 @app.get("/health", response_model=HealthResponse)
 async def health_check() -> HealthResponse:
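A quick way to verify the middleware once the server is running (a sketch that assumes the API is reachable at http://localhost:8000) is to issue a preflight request and check that the allowed origin is echoed back:

# Sketch: CORS preflight check against a locally running instance.
import httpx

response = httpx.options(
    "http://localhost:8000/chat/stream",
    headers={
        "Origin": "http://localhost:3000",
        "Access-Control-Request-Method": "POST",
    },
)
# With CORSMiddleware configured, this prints the allowed origin.
print(response.headers.get("access-control-allow-origin"))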
@@ -89,6 +107,79 @@ async def chat(request: ChatRequest, adapter: AdapterDependency) -> ChatResponse
         )
 
 
+@app.post("/chat/stream")
+async def chat_stream(request: ChatRequest, adapter: AdapterDependency) -> StreamingResponse:
+    """Stream a chat response through the LLM adapter using Server-Sent Events.
+
+    - Validates message length
+    - Generates conversation_id if not provided
+    - Routes to appropriate LLM adapter based on LLM_MODE
+    - Returns streaming response with SSE format
+    """
+    # Validate message length
+    if len(request.message) > MAX_MESSAGE_LENGTH:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Message exceeds maximum length of {MAX_MESSAGE_LENGTH:,} characters. "
+            f"Your message has {len(request.message):,} characters.",
+        )
+
+    # Generate or use provided conversation_id
+    conversation_id = request.conversation_id or str(uuid.uuid4())
+
+    # Log request metadata (not content)
+    logger.info(
+        "Chat stream request received",
+        extra={
+            "conversation_id": conversation_id,
+            "message_length": len(request.message),
+            "mode": settings.llm_mode,
+        },
+    )
+
+    async def stream_response():
+        """Async generator that yields SSE-formatted events."""
+        try:
+            async for chunk in adapter.generate_stream(conversation_id, request.message):
+                event = StreamChunkEvent(content=chunk, conversation_id=conversation_id)
+                yield f"data: {event.model_dump_json()}\n\n"
+
+            # Send completion event
+            done_event = StreamDoneEvent(
+                conversation_id=conversation_id, mode=settings.llm_mode
+            )
+            yield f"data: {done_event.model_dump_json()}\n\n"
+
+            logger.info(
+                "Chat stream completed",
+                extra={
+                    "conversation_id": conversation_id,
+                    "mode": settings.llm_mode,
+                },
+            )
+
+        except LLMError as e:
+            logger.error(
+                "LLM streaming failed",
+                extra={
+                    "conversation_id": conversation_id,
+                    "error_type": type(e).__name__,
+                    "error_message": e.message,
+                },
+            )
+            error_event = StreamErrorEvent(message=e.message, code=e.status_code)
+            yield f"data: {error_event.model_dump_json()}\n\n"
+
+    return StreamingResponse(
+        stream_response(),
+        media_type="text/event-stream",
+        headers={
+            "Cache-Control": "no-cache",
+            "Connection": "keep-alive",
+        },
+    )
+
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run("app.main:app", host="127.0.0.1", port=8000, reload=True)
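On the client side, consuming the endpoint means reading "data:" lines off the response body until the done (or error) event arrives. A minimal sketch (assuming the API is reachable at http://localhost:8000 and using the ChatRequest "message" field seen above):

# Sketch: stream a chat response from a locally running instance.
import json

import httpx

with httpx.Client(timeout=None) as client:
    with client.stream(
        "POST", "http://localhost:8000/chat/stream", json={"message": "Hello"}
    ) as response:
        for line in response.iter_lines():
            if not line.startswith("data: "):
                continue
            event = json.loads(line[len("data: "):])
            if event["type"] == "chunk":
                print(event["content"], end="", flush=True)
            elif event["type"] == "done":
                print()  # stream finished
            elif event["type"] == "error":
                print(f"\nstream error {event['code']}: {event['message']}")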
@@ -31,3 +31,30 @@ class ErrorResponse(BaseModel):
     """Standard error response model."""
 
     detail: str = Field(..., description="Error description")
+
+
+# --- SSE Streaming Event Models ---
+
+
+class StreamChunkEvent(BaseModel):
+    """SSE event for content chunks during streaming."""
+
+    type: Literal["chunk"] = "chunk"
+    content: str = Field(..., description="Content chunk from the LLM")
+    conversation_id: str = Field(..., description="Conversation ID")
+
+
+class StreamDoneEvent(BaseModel):
+    """SSE event signaling completion of streaming."""
+
+    type: Literal["done"] = "done"
+    conversation_id: str = Field(..., description="Conversation ID")
+    mode: Literal["local", "remote", "openai", "asksage"] = Field(..., description="Which adapter was used")
+
+
+class StreamErrorEvent(BaseModel):
+    """SSE event for errors during streaming."""
+
+    type: Literal["error"] = "error"
+    message: str = Field(..., description="Error message")
+    code: int = Field(default=500, description="HTTP status code")
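Tying the pieces together: each event above is serialized with model_dump_json() and framed as an SSE "data:" line followed by a blank line. Roughly what a successful stream looks like on the wire (placeholder values; exact field order may vary):

data: {"type":"chunk","content":"Hel","conversation_id":"abc-123"}

data: {"type":"chunk","content":"lo!","conversation_id":"abc-123"}

data: {"type":"done","conversation_id":"abc-123","mode":"openai"}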