"""LLM adapter implementations with FastAPI dependency injection support.""" from abc import ABC, abstractmethod from functools import lru_cache from typing import Annotated import httpx from fastapi import Depends from openai import ( AsyncOpenAI, AuthenticationError, RateLimitError, APIConnectionError, APIError, ) from app.config import Settings, settings from app.llm.exceptions import ( LLMError, LLMAuthenticationError, LLMRateLimitError, LLMConnectionError, LLMConfigurationError, LLMResponseError, ) class LLMAdapter(ABC): """Abstract base class for LLM adapters.""" @abstractmethod async def generate(self, conversation_id: str, message: str) -> str: """Generate a response for the given message. Args: conversation_id: The conversation identifier message: The user's message Returns: The generated response string Raises: LLMError: If generation fails for any reason """ pass class LocalAdapter(LLMAdapter): """Local stub adapter for development and testing.""" async def generate(self, conversation_id: str, message: str) -> str: """Return a stub response echoing the user message. This is a placeholder that will be replaced with a real local model. """ return ( f"[LOCAL STUB MODE] Acknowledged your message. " f"You said: \"{message[:100]}{'...' if len(message) > 100 else ''}\". " f"This is a stub response - local model not yet implemented." ) class RemoteAdapter(LLMAdapter): """Remote adapter that calls an external LLM service via HTTP.""" def __init__(self, url: str, token: str | None = None, timeout: float = 30.0): """Initialize the remote adapter. Args: url: The remote LLM service URL token: Optional bearer token for authentication timeout: Request timeout in seconds """ self.url = url self.token = token self.timeout = timeout async def generate(self, conversation_id: str, message: str) -> str: """Call the remote LLM service to generate a response. Raises: LLMConnectionError: If connection fails or times out LLMAuthenticationError: If authentication fails LLMRateLimitError: If rate limit is exceeded LLMResponseError: If response is invalid LLMError: For other HTTP errors """ headers = {"Content-Type": "application/json"} if self.token: headers["Authorization"] = f"Bearer {self.token}" payload = { "conversation_id": conversation_id, "message": message, } try: async with httpx.AsyncClient(timeout=self.timeout) as client: response = await client.post(self.url, json=payload, headers=headers) if response.status_code == 401: raise LLMAuthenticationError("Remote LLM authentication failed") if response.status_code == 429: raise LLMRateLimitError("Remote LLM rate limit exceeded") if response.status_code != 200: raise LLMError( f"Remote LLM returned status {response.status_code}: " f"{response.text[:200] if response.text else 'No response body'}", status_code=response.status_code if 400 <= response.status_code < 600 else 502, ) try: data = response.json() except ValueError: raise LLMResponseError("Remote LLM returned invalid JSON response") if "response" not in data: raise LLMResponseError("Remote LLM response missing 'response' field") return data["response"] except httpx.TimeoutException: raise LLMConnectionError( f"Remote LLM request timed out after {self.timeout} seconds" ) except httpx.ConnectError: raise LLMConnectionError(f"Could not connect to remote LLM at {self.url}") except httpx.RequestError as e: raise LLMConnectionError(f"Remote LLM request failed: {str(e)}") class OpenAIAdapter(LLMAdapter): """OpenAI API adapter using the official SDK with native async support.""" def __init__(self, api_key: str, model: str = "gpt-4o-mini"): """Initialize the OpenAI adapter. Args: api_key: OpenAI API key model: Model identifier (e.g., "gpt-4o-mini", "gpt-4o") """ self.client = AsyncOpenAI(api_key=api_key) self.model = model async def generate(self, conversation_id: str, message: str) -> str: """Generate a response using the OpenAI API. Args: conversation_id: The conversation identifier (for future use with context) message: The user's message Returns: The generated response string Raises: LLMAuthenticationError: If API key is invalid LLMRateLimitError: If rate limit is exceeded LLMConnectionError: If connection fails LLMResponseError: If response content is empty LLMError: For other API errors """ try: response = await self.client.chat.completions.create( model=self.model, messages=[{"role": "user", "content": message}], ) content = response.choices[0].message.content if content is None: raise LLMResponseError("OpenAI returned empty response content") return content except AuthenticationError as e: raise LLMAuthenticationError(f"OpenAI authentication failed: {e.message}") except RateLimitError as e: raise LLMRateLimitError(f"OpenAI rate limit exceeded: {e.message}") except APIConnectionError as e: raise LLMConnectionError(f"Could not connect to OpenAI: {str(e)}") except APIError as e: raise LLMError( f"OpenAI API error: {e.message}", status_code=e.status_code or 500 ) # --- Dependency Injection Support --- @lru_cache() def get_settings() -> Settings: """Get cached settings instance for dependency injection.""" return settings def get_adapter(settings: Annotated[Settings, Depends(get_settings)]) -> LLMAdapter: """Factory function to create the appropriate adapter based on configuration. This function is designed for use with FastAPI's Depends() system. Args: settings: Application settings (injected by FastAPI) Returns: An LLMAdapter instance based on the LLM_MODE setting Raises: LLMConfigurationError: If configuration is invalid for the selected mode """ if settings.llm_mode == "openai": if not settings.openai_api_key: raise LLMConfigurationError( "OPENAI_API_KEY must be set when LLM_MODE is 'openai'" ) return OpenAIAdapter( api_key=settings.openai_api_key, model=settings.openai_model, ) if settings.llm_mode == "remote": if not settings.llm_remote_url: raise LLMConfigurationError( "LLM_REMOTE_URL must be set when LLM_MODE is 'remote'" ) return RemoteAdapter( url=settings.llm_remote_url, token=settings.llm_remote_token or None, ) return LocalAdapter() # Type alias for clean dependency injection in endpoints AdapterDependency = Annotated[LLMAdapter, Depends(get_adapter)]