diff --git a/.env.example b/.env.example
index af31b1c..37e074f 100644
--- a/.env.example
+++ b/.env.example
@@ -1,6 +1,9 @@
 # LLM Mode: "local", "remote", "openai", or "asksage"
 LLM_MODE=local
 
+# CORS Configuration (comma-separated origins for frontend access)
+CORS_ORIGINS=http://localhost:3000
+
 # Remote LLM Configuration (required if LLM_MODE=remote)
 LLM_REMOTE_URL=https://your-llm-service.com/generate
 LLM_REMOTE_TOKEN=
diff --git a/app/config.py b/app/config.py
index 9d76872..e1af224 100644
--- a/app/config.py
+++ b/app/config.py
@@ -19,6 +19,14 @@ class Settings(BaseSettings):
     asksage_api_key: str = ""
     asksage_model: str = "gpt-4o"
 
+    # CORS configuration
+    cors_origins: str = "http://localhost:3000"
+
+    @property
+    def cors_origins_list(self) -> list[str]:
+        """Parse comma-separated CORS origins into a list."""
+        return [origin.strip() for origin in self.cors_origins.split(",") if origin.strip()]
+
     class Config:
         env_file = ".env"
         env_file_encoding = "utf-8"
diff --git a/app/llm/adapter.py b/app/llm/adapter.py
index c1a3613..596d3dd 100644
--- a/app/llm/adapter.py
+++ b/app/llm/adapter.py
@@ -3,6 +3,7 @@ import asyncio
 
 from abc import ABC, abstractmethod
 from functools import lru_cache
+from collections.abc import AsyncIterator
 from typing import Annotated
 
 import httpx
@@ -46,6 +47,27 @@ class LLMAdapter(ABC):
         """
         pass
 
+    async def generate_stream(
+        self, conversation_id: str, message: str
+    ) -> AsyncIterator[str]:
+        """Stream a response for the given message.
+
+        Default implementation yields the full response as a single chunk.
+        Subclasses can override this to provide true streaming.
+
+        Args:
+            conversation_id: The conversation identifier
+            message: The user's message
+
+        Yields:
+            Response content chunks
+
+        Raises:
+            LLMError: If generation fails for any reason
+        """
+        response = await self.generate(conversation_id, message)
+        yield response
+
 
 class LocalAdapter(LLMAdapter):
     """Local stub adapter for development and testing."""
@@ -183,6 +205,46 @@ class OpenAIAdapter(LLMAdapter):
                 f"OpenAI API error: {e.message}", status_code=e.status_code or 500
             )
 
+    async def generate_stream(
+        self, conversation_id: str, message: str
+    ) -> AsyncIterator[str]:
+        """Stream a response using the OpenAI API.
+
+        Args:
+            conversation_id: The conversation identifier (for future use with context)
+            message: The user's message
+
+        Yields:
+            Response content chunks
+
+        Raises:
+            LLMAuthenticationError: If API key is invalid
+            LLMRateLimitError: If rate limit is exceeded
+            LLMConnectionError: If connection fails
+            LLMError: For other API errors
+        """
+        try:
+            stream = await self.client.chat.completions.create(
+                model=self.model,
+                messages=[{"role": "user", "content": message}],
+                stream=True,
+            )
+
+            async for chunk in stream:
+                if chunk.choices and chunk.choices[0].delta.content:
+                    yield chunk.choices[0].delta.content
+
+        except AuthenticationError as e:
+            raise LLMAuthenticationError(f"OpenAI authentication failed: {e.message}")
+        except RateLimitError as e:
+            raise LLMRateLimitError(f"OpenAI rate limit exceeded: {e.message}")
+        except APIConnectionError as e:
+            raise LLMConnectionError(f"Could not connect to OpenAI: {str(e)}")
+        except APIError as e:
+            raise LLMError(
+                f"OpenAI API error: {e.message}", status_code=e.status_code or 500
+            )
+
 
 class AskSageAdapter(LLMAdapter):
     """AskSage API adapter using the official asksageclient SDK."""
diff --git a/app/main.py b/app/main.py
index 3ccdaf8..cae6f15 100644
--- a/app/main.py
+++ b/app/main.py
@@ -2,10 +2,19 @@ import logging
 import uuid
 
 from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
 
 from app.config import settings, MAX_MESSAGE_LENGTH
 from app.llm import AdapterDependency, LLMError, llm_exception_to_http
-from app.schemas import ChatRequest, ChatResponse, HealthResponse
+from app.schemas import (
+    ChatRequest,
+    ChatResponse,
+    HealthResponse,
+    StreamChunkEvent,
+    StreamDoneEvent,
+    StreamErrorEvent,
+)
 
 # Configure logging
 logging.basicConfig(
@@ -21,6 +30,15 @@ app = FastAPI(
     version="0.1.0",
 )
 
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=settings.cors_origins_list,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
 
 @app.get("/health", response_model=HealthResponse)
 async def health_check() -> HealthResponse:
@@ -89,6 +107,79 @@ async def chat(request: ChatRequest, adapter: AdapterDependency) -> ChatResponse
     )
 
 
+@app.post("/chat/stream")
+async def chat_stream(request: ChatRequest, adapter: AdapterDependency) -> StreamingResponse:
+    """Stream a chat response through the LLM adapter using Server-Sent Events.
+
+    - Validates message length
+    - Generates conversation_id if not provided
+    - Routes to appropriate LLM adapter based on LLM_MODE
+    - Returns streaming response with SSE format
+    """
+    # Validate message length
+    if len(request.message) > MAX_MESSAGE_LENGTH:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Message exceeds maximum length of {MAX_MESSAGE_LENGTH:,} characters. "
+            f"Your message has {len(request.message):,} characters.",
+        )
+
+    # Generate or use provided conversation_id
+    conversation_id = request.conversation_id or str(uuid.uuid4())
+
+    # Log request metadata (not content)
+    logger.info(
+        "Chat stream request received",
+        extra={
+            "conversation_id": conversation_id,
+            "message_length": len(request.message),
+            "mode": settings.llm_mode,
+        },
+    )
+
+    async def stream_response():
+        """Async generator that yields SSE-formatted events."""
+        try:
+            async for chunk in adapter.generate_stream(conversation_id, request.message):
+                event = StreamChunkEvent(content=chunk, conversation_id=conversation_id)
+                yield f"data: {event.model_dump_json()}\n\n"
+
+            # Send completion event
+            done_event = StreamDoneEvent(
+                conversation_id=conversation_id, mode=settings.llm_mode
+            )
+            yield f"data: {done_event.model_dump_json()}\n\n"
+
+            logger.info(
+                "Chat stream completed",
+                extra={
+                    "conversation_id": conversation_id,
+                    "mode": settings.llm_mode,
+                },
+            )
+
+        except LLMError as e:
+            logger.error(
+                "LLM streaming failed",
+                extra={
+                    "conversation_id": conversation_id,
+                    "error_type": type(e).__name__,
+                    "error_message": e.message,
+                },
+            )
+            error_event = StreamErrorEvent(message=e.message, code=e.status_code)
+            yield f"data: {error_event.model_dump_json()}\n\n"
+
+    return StreamingResponse(
+        stream_response(),
+        media_type="text/event-stream",
+        headers={
+            "Cache-Control": "no-cache",
+            "Connection": "keep-alive",
+        },
+    )
+
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run("app.main:app", host="127.0.0.1", port=8000, reload=True)
diff --git a/app/schemas.py b/app/schemas.py
index 1bf1b98..f645920 100644
--- a/app/schemas.py
+++ b/app/schemas.py
@@ -31,3 +31,30 @@ class ErrorResponse(BaseModel):
     """Standard error response model."""
 
     detail: str = Field(..., description="Error description")
+
+
+# --- SSE Streaming Event Models ---
+
+
+class StreamChunkEvent(BaseModel):
+    """SSE event for content chunks during streaming."""
+
+    type: Literal["chunk"] = "chunk"
+    content: str = Field(..., description="Content chunk from the LLM")
+    conversation_id: str = Field(..., description="Conversation ID")
+
+
+class StreamDoneEvent(BaseModel):
+    """SSE event signaling completion of streaming."""
+
+    type: Literal["done"] = "done"
+    conversation_id: str = Field(..., description="Conversation ID")
+    mode: Literal["local", "remote", "openai", "asksage"] = Field(..., description="Which adapter was used")
+
+
+class StreamErrorEvent(BaseModel):
+    """SSE event for errors during streaming."""
+
+    type: Literal["error"] = "error"
+    message: str = Field(..., description="Error message")
+    code: int = Field(default=500, description="HTTP status code")