feat: add OpenAI integration with dependency injection support

- Add OpenAIAdapter class using official OpenAI SDK with async support
- Create custom exception hierarchy for LLM errors (authentication,
  rate limit, connection, configuration, response errors)
- Refactor adapter factory to use FastAPI Depends() for dependency injection
- Update configuration to support 'openai' mode with API key and model settings
- Add proper HTTP error mapping for all LLM exception types
- Update Dockerfile with default OPENAI_MODEL environment variable
- Update .env.example with OpenAI configuration options
This commit is contained in:
Danny
2026-01-13 15:17:44 -06:00
parent d5525b12b2
commit 3324b6ac12
11 changed files with 434 additions and 22 deletions
+5 -1
View File
@@ -6,10 +6,14 @@ from pydantic_settings import BaseSettings
class Settings(BaseSettings):
"""Application configuration loaded from environment variables."""
llm_mode: Literal["local", "remote"] = "local"
llm_mode: Literal["local", "remote", "openai"] = "local"
llm_remote_url: str = ""
llm_remote_token: str = ""
# OpenAI configuration
openai_api_key: str = ""
openai_model: str = "gpt-4o-mini"
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
+40 -1
View File
@@ -1 +1,40 @@
# LLM adapters
"""LLM adapters and utilities."""
from app.llm.adapter import (
LLMAdapter,
LocalAdapter,
RemoteAdapter,
OpenAIAdapter,
get_adapter,
get_settings,
AdapterDependency,
)
from app.llm.exceptions import (
LLMError,
LLMAuthenticationError,
LLMRateLimitError,
LLMConnectionError,
LLMConfigurationError,
LLMResponseError,
llm_exception_to_http,
)
__all__ = [
# Adapters
"LLMAdapter",
"LocalAdapter",
"RemoteAdapter",
"OpenAIAdapter",
# DI support
"get_adapter",
"get_settings",
"AdapterDependency",
# Exceptions
"LLMError",
"LLMAuthenticationError",
"LLMRateLimitError",
"LLMConnectionError",
"LLMConfigurationError",
"LLMResponseError",
"llm_exception_to_http",
]
+134 -12
View File
@@ -1,8 +1,28 @@
"""LLM adapter implementations with FastAPI dependency injection support."""
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Annotated
import httpx
from fastapi import Depends
from openai import (
AsyncOpenAI,
AuthenticationError,
RateLimitError,
APIConnectionError,
APIError,
)
from app.config import settings
from app.config import Settings, settings
from app.llm.exceptions import (
LLMError,
LLMAuthenticationError,
LLMRateLimitError,
LLMConnectionError,
LLMConfigurationError,
LLMResponseError,
)
class LLMAdapter(ABC):
@@ -18,6 +38,9 @@ class LLMAdapter(ABC):
Returns:
The generated response string
Raises:
LLMError: If generation fails for any reason
"""
pass
@@ -55,7 +78,12 @@ class RemoteAdapter(LLMAdapter):
async def generate(self, conversation_id: str, message: str) -> str:
"""Call the remote LLM service to generate a response.
Handles errors gracefully by returning informative error strings.
Raises:
LLMConnectionError: If connection fails or times out
LLMAuthenticationError: If authentication fails
LLMRateLimitError: If rate limit is exceeded
LLMResponseError: If response is invalid
LLMError: For other HTTP errors
"""
headers = {"Content-Type": "application/json"}
if self.token:
@@ -70,41 +98,135 @@ class RemoteAdapter(LLMAdapter):
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(self.url, json=payload, headers=headers)
if response.status_code == 401:
raise LLMAuthenticationError("Remote LLM authentication failed")
if response.status_code == 429:
raise LLMRateLimitError("Remote LLM rate limit exceeded")
if response.status_code != 200:
return (
f"[ERROR] Remote LLM returned status {response.status_code}: "
f"{response.text[:200] if response.text else 'No response body'}"
raise LLMError(
f"Remote LLM returned status {response.status_code}: "
f"{response.text[:200] if response.text else 'No response body'}",
status_code=response.status_code if 400 <= response.status_code < 600 else 502,
)
try:
data = response.json()
except ValueError:
return "[ERROR] Remote LLM returned invalid JSON response"
raise LLMResponseError("Remote LLM returned invalid JSON response")
if "response" not in data:
return "[ERROR] Remote LLM response missing 'response' field"
raise LLMResponseError("Remote LLM response missing 'response' field")
return data["response"]
except httpx.TimeoutException:
return f"[ERROR] Remote LLM request timed out after {self.timeout} seconds"
raise LLMConnectionError(
f"Remote LLM request timed out after {self.timeout} seconds"
)
except httpx.ConnectError:
return f"[ERROR] Could not connect to remote LLM at {self.url}"
raise LLMConnectionError(f"Could not connect to remote LLM at {self.url}")
except httpx.RequestError as e:
return f"[ERROR] Remote LLM request failed: {str(e)}"
raise LLMConnectionError(f"Remote LLM request failed: {str(e)}")
def get_adapter() -> LLMAdapter:
class OpenAIAdapter(LLMAdapter):
"""OpenAI API adapter using the official SDK with native async support."""
def __init__(self, api_key: str, model: str = "gpt-4o-mini"):
"""Initialize the OpenAI adapter.
Args:
api_key: OpenAI API key
model: Model identifier (e.g., "gpt-4o-mini", "gpt-4o")
"""
self.client = AsyncOpenAI(api_key=api_key)
self.model = model
async def generate(self, conversation_id: str, message: str) -> str:
"""Generate a response using the OpenAI API.
Args:
conversation_id: The conversation identifier (for future use with context)
message: The user's message
Returns:
The generated response string
Raises:
LLMAuthenticationError: If API key is invalid
LLMRateLimitError: If rate limit is exceeded
LLMConnectionError: If connection fails
LLMResponseError: If response content is empty
LLMError: For other API errors
"""
try:
response = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": message}],
)
content = response.choices[0].message.content
if content is None:
raise LLMResponseError("OpenAI returned empty response content")
return content
except AuthenticationError as e:
raise LLMAuthenticationError(f"OpenAI authentication failed: {e.message}")
except RateLimitError as e:
raise LLMRateLimitError(f"OpenAI rate limit exceeded: {e.message}")
except APIConnectionError as e:
raise LLMConnectionError(f"Could not connect to OpenAI: {str(e)}")
except APIError as e:
raise LLMError(
f"OpenAI API error: {e.message}", status_code=e.status_code or 500
)
# --- Dependency Injection Support ---
@lru_cache()
def get_settings() -> Settings:
"""Get cached settings instance for dependency injection."""
return settings
def get_adapter(settings: Annotated[Settings, Depends(get_settings)]) -> LLMAdapter:
"""Factory function to create the appropriate adapter based on configuration.
This function is designed for use with FastAPI's Depends() system.
Args:
settings: Application settings (injected by FastAPI)
Returns:
An LLMAdapter instance based on the LLM_MODE setting
Raises:
LLMConfigurationError: If configuration is invalid for the selected mode
"""
if settings.llm_mode == "openai":
if not settings.openai_api_key:
raise LLMConfigurationError(
"OPENAI_API_KEY must be set when LLM_MODE is 'openai'"
)
return OpenAIAdapter(
api_key=settings.openai_api_key,
model=settings.openai_model,
)
if settings.llm_mode == "remote":
if not settings.llm_remote_url:
raise ValueError("LLM_REMOTE_URL must be set when LLM_MODE is 'remote'")
raise LLMConfigurationError(
"LLM_REMOTE_URL must be set when LLM_MODE is 'remote'"
)
return RemoteAdapter(
url=settings.llm_remote_url,
token=settings.llm_remote_token or None,
)
return LocalAdapter()
# Type alias for clean dependency injection in endpoints
AdapterDependency = Annotated[LLMAdapter, Depends(get_adapter)]
+59
View File
@@ -0,0 +1,59 @@
"""Custom exceptions for LLM adapters."""
from fastapi import HTTPException
class LLMError(Exception):
"""Base exception for all LLM-related errors."""
def __init__(self, message: str, status_code: int = 500):
self.message = message
self.status_code = status_code
super().__init__(message)
class LLMAuthenticationError(LLMError):
"""Raised when API authentication fails."""
def __init__(self, message: str = "LLM authentication failed"):
super().__init__(message, status_code=401)
class LLMRateLimitError(LLMError):
"""Raised when rate limit is exceeded."""
def __init__(self, message: str = "LLM rate limit exceeded"):
super().__init__(message, status_code=429)
class LLMConnectionError(LLMError):
"""Raised when connection to LLM service fails."""
def __init__(self, message: str = "Could not connect to LLM service"):
super().__init__(message, status_code=503)
class LLMConfigurationError(LLMError):
"""Raised when LLM configuration is invalid."""
def __init__(self, message: str = "Invalid LLM configuration"):
super().__init__(message, status_code=500)
class LLMResponseError(LLMError):
"""Raised when LLM returns an invalid or unexpected response."""
def __init__(self, message: str = "Invalid response from LLM"):
super().__init__(message, status_code=502)
def llm_exception_to_http(exc: LLMError) -> HTTPException:
"""Convert an LLMError to a FastAPI HTTPException.
Args:
exc: The LLMError to convert
Returns:
An HTTPException with appropriate status code and detail
"""
return HTTPException(status_code=exc.status_code, detail=exc.message)
+20 -5
View File
@@ -4,7 +4,7 @@ import uuid
from fastapi import FastAPI, HTTPException
from app.config import settings, MAX_MESSAGE_LENGTH
from app.llm.adapter import get_adapter
from app.llm import AdapterDependency, LLMError, llm_exception_to_http
from app.schemas import ChatRequest, ChatResponse, HealthResponse
# Configure logging
@@ -29,7 +29,7 @@ async def health_check() -> HealthResponse:
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest) -> ChatResponse:
async def chat(request: ChatRequest, adapter: AdapterDependency) -> ChatResponse:
"""Process a chat message through the LLM adapter.
- Validates message length
@@ -57,9 +57,19 @@ async def chat(request: ChatRequest) -> ChatResponse:
},
)
# Get adapter and generate response
adapter = get_adapter()
response_text = await adapter.generate(conversation_id, request.message)
# Generate response with exception handling
try:
response_text = await adapter.generate(conversation_id, request.message)
except LLMError as e:
logger.error(
"LLM generation failed",
extra={
"conversation_id": conversation_id,
"error_type": type(e).__name__,
"error_message": e.message,
},
)
raise llm_exception_to_http(e)
# Log response metadata
logger.info(
@@ -77,3 +87,8 @@ async def chat(request: ChatRequest) -> ChatResponse:
mode=settings.llm_mode,
sources=[],
)
if __name__ == "__main__":
import uvicorn
uvicorn.run("app.main:app", host="127.0.0.1", port=8000, reload=True)
+1 -1
View File
@@ -17,7 +17,7 @@ class ChatResponse(BaseModel):
conversation_id: str = Field(..., description="Conversation ID (generated if not provided)")
response: str = Field(..., description="The LLM's response")
mode: Literal["local", "remote"] = Field(..., description="Which adapter was used")
mode: Literal["local", "remote", "openai"] = Field(..., description="Which adapter was used")
sources: list = Field(default_factory=list, description="Source references (empty for now)")