feat: add OpenAI integration with dependency injection support
- Add OpenAIAdapter class using official OpenAI SDK with async support - Create custom exception hierarchy for LLM errors (authentication, rate limit, connection, configuration, response errors) - Refactor adapter factory to use FastAPI Depends() for dependency injection - Update configuration to support 'openai' mode with API key and model settings - Add proper HTTP error mapping for all LLM exception types - Update Dockerfile with default OPENAI_MODEL environment variable - Update .env.example with OpenAI configuration options
This commit is contained in:
+5
-1
@@ -6,10 +6,14 @@ from pydantic_settings import BaseSettings
|
||||
class Settings(BaseSettings):
|
||||
"""Application configuration loaded from environment variables."""
|
||||
|
||||
llm_mode: Literal["local", "remote"] = "local"
|
||||
llm_mode: Literal["local", "remote", "openai"] = "local"
|
||||
llm_remote_url: str = ""
|
||||
llm_remote_token: str = ""
|
||||
|
||||
# OpenAI configuration
|
||||
openai_api_key: str = ""
|
||||
openai_model: str = "gpt-4o-mini"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
|
||||
+40
-1
@@ -1 +1,40 @@
|
||||
# LLM adapters
|
||||
"""LLM adapters and utilities."""
|
||||
|
||||
from app.llm.adapter import (
|
||||
LLMAdapter,
|
||||
LocalAdapter,
|
||||
RemoteAdapter,
|
||||
OpenAIAdapter,
|
||||
get_adapter,
|
||||
get_settings,
|
||||
AdapterDependency,
|
||||
)
|
||||
from app.llm.exceptions import (
|
||||
LLMError,
|
||||
LLMAuthenticationError,
|
||||
LLMRateLimitError,
|
||||
LLMConnectionError,
|
||||
LLMConfigurationError,
|
||||
LLMResponseError,
|
||||
llm_exception_to_http,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Adapters
|
||||
"LLMAdapter",
|
||||
"LocalAdapter",
|
||||
"RemoteAdapter",
|
||||
"OpenAIAdapter",
|
||||
# DI support
|
||||
"get_adapter",
|
||||
"get_settings",
|
||||
"AdapterDependency",
|
||||
# Exceptions
|
||||
"LLMError",
|
||||
"LLMAuthenticationError",
|
||||
"LLMRateLimitError",
|
||||
"LLMConnectionError",
|
||||
"LLMConfigurationError",
|
||||
"LLMResponseError",
|
||||
"llm_exception_to_http",
|
||||
]
|
||||
|
||||
+134
-12
@@ -1,8 +1,28 @@
|
||||
"""LLM adapter implementations with FastAPI dependency injection support."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import lru_cache
|
||||
from typing import Annotated
|
||||
|
||||
import httpx
|
||||
from fastapi import Depends
|
||||
from openai import (
|
||||
AsyncOpenAI,
|
||||
AuthenticationError,
|
||||
RateLimitError,
|
||||
APIConnectionError,
|
||||
APIError,
|
||||
)
|
||||
|
||||
from app.config import settings
|
||||
from app.config import Settings, settings
|
||||
from app.llm.exceptions import (
|
||||
LLMError,
|
||||
LLMAuthenticationError,
|
||||
LLMRateLimitError,
|
||||
LLMConnectionError,
|
||||
LLMConfigurationError,
|
||||
LLMResponseError,
|
||||
)
|
||||
|
||||
|
||||
class LLMAdapter(ABC):
|
||||
@@ -18,6 +38,9 @@ class LLMAdapter(ABC):
|
||||
|
||||
Returns:
|
||||
The generated response string
|
||||
|
||||
Raises:
|
||||
LLMError: If generation fails for any reason
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -55,7 +78,12 @@ class RemoteAdapter(LLMAdapter):
|
||||
async def generate(self, conversation_id: str, message: str) -> str:
|
||||
"""Call the remote LLM service to generate a response.
|
||||
|
||||
Handles errors gracefully by returning informative error strings.
|
||||
Raises:
|
||||
LLMConnectionError: If connection fails or times out
|
||||
LLMAuthenticationError: If authentication fails
|
||||
LLMRateLimitError: If rate limit is exceeded
|
||||
LLMResponseError: If response is invalid
|
||||
LLMError: For other HTTP errors
|
||||
"""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.token:
|
||||
@@ -70,41 +98,135 @@ class RemoteAdapter(LLMAdapter):
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(self.url, json=payload, headers=headers)
|
||||
|
||||
if response.status_code == 401:
|
||||
raise LLMAuthenticationError("Remote LLM authentication failed")
|
||||
if response.status_code == 429:
|
||||
raise LLMRateLimitError("Remote LLM rate limit exceeded")
|
||||
if response.status_code != 200:
|
||||
return (
|
||||
f"[ERROR] Remote LLM returned status {response.status_code}: "
|
||||
f"{response.text[:200] if response.text else 'No response body'}"
|
||||
raise LLMError(
|
||||
f"Remote LLM returned status {response.status_code}: "
|
||||
f"{response.text[:200] if response.text else 'No response body'}",
|
||||
status_code=response.status_code if 400 <= response.status_code < 600 else 502,
|
||||
)
|
||||
|
||||
try:
|
||||
data = response.json()
|
||||
except ValueError:
|
||||
return "[ERROR] Remote LLM returned invalid JSON response"
|
||||
raise LLMResponseError("Remote LLM returned invalid JSON response")
|
||||
|
||||
if "response" not in data:
|
||||
return "[ERROR] Remote LLM response missing 'response' field"
|
||||
raise LLMResponseError("Remote LLM response missing 'response' field")
|
||||
|
||||
return data["response"]
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return f"[ERROR] Remote LLM request timed out after {self.timeout} seconds"
|
||||
raise LLMConnectionError(
|
||||
f"Remote LLM request timed out after {self.timeout} seconds"
|
||||
)
|
||||
except httpx.ConnectError:
|
||||
return f"[ERROR] Could not connect to remote LLM at {self.url}"
|
||||
raise LLMConnectionError(f"Could not connect to remote LLM at {self.url}")
|
||||
except httpx.RequestError as e:
|
||||
return f"[ERROR] Remote LLM request failed: {str(e)}"
|
||||
raise LLMConnectionError(f"Remote LLM request failed: {str(e)}")
|
||||
|
||||
|
||||
def get_adapter() -> LLMAdapter:
|
||||
class OpenAIAdapter(LLMAdapter):
|
||||
"""OpenAI API adapter using the official SDK with native async support."""
|
||||
|
||||
def __init__(self, api_key: str, model: str = "gpt-4o-mini"):
|
||||
"""Initialize the OpenAI adapter.
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API key
|
||||
model: Model identifier (e.g., "gpt-4o-mini", "gpt-4o")
|
||||
"""
|
||||
self.client = AsyncOpenAI(api_key=api_key)
|
||||
self.model = model
|
||||
|
||||
async def generate(self, conversation_id: str, message: str) -> str:
|
||||
"""Generate a response using the OpenAI API.
|
||||
|
||||
Args:
|
||||
conversation_id: The conversation identifier (for future use with context)
|
||||
message: The user's message
|
||||
|
||||
Returns:
|
||||
The generated response string
|
||||
|
||||
Raises:
|
||||
LLMAuthenticationError: If API key is invalid
|
||||
LLMRateLimitError: If rate limit is exceeded
|
||||
LLMConnectionError: If connection fails
|
||||
LLMResponseError: If response content is empty
|
||||
LLMError: For other API errors
|
||||
"""
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[{"role": "user", "content": message}],
|
||||
)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
if content is None:
|
||||
raise LLMResponseError("OpenAI returned empty response content")
|
||||
return content
|
||||
|
||||
except AuthenticationError as e:
|
||||
raise LLMAuthenticationError(f"OpenAI authentication failed: {e.message}")
|
||||
except RateLimitError as e:
|
||||
raise LLMRateLimitError(f"OpenAI rate limit exceeded: {e.message}")
|
||||
except APIConnectionError as e:
|
||||
raise LLMConnectionError(f"Could not connect to OpenAI: {str(e)}")
|
||||
except APIError as e:
|
||||
raise LLMError(
|
||||
f"OpenAI API error: {e.message}", status_code=e.status_code or 500
|
||||
)
|
||||
|
||||
|
||||
# --- Dependency Injection Support ---
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_settings() -> Settings:
|
||||
"""Get cached settings instance for dependency injection."""
|
||||
return settings
|
||||
|
||||
|
||||
def get_adapter(settings: Annotated[Settings, Depends(get_settings)]) -> LLMAdapter:
|
||||
"""Factory function to create the appropriate adapter based on configuration.
|
||||
|
||||
This function is designed for use with FastAPI's Depends() system.
|
||||
|
||||
Args:
|
||||
settings: Application settings (injected by FastAPI)
|
||||
|
||||
Returns:
|
||||
An LLMAdapter instance based on the LLM_MODE setting
|
||||
|
||||
Raises:
|
||||
LLMConfigurationError: If configuration is invalid for the selected mode
|
||||
"""
|
||||
if settings.llm_mode == "openai":
|
||||
if not settings.openai_api_key:
|
||||
raise LLMConfigurationError(
|
||||
"OPENAI_API_KEY must be set when LLM_MODE is 'openai'"
|
||||
)
|
||||
return OpenAIAdapter(
|
||||
api_key=settings.openai_api_key,
|
||||
model=settings.openai_model,
|
||||
)
|
||||
|
||||
if settings.llm_mode == "remote":
|
||||
if not settings.llm_remote_url:
|
||||
raise ValueError("LLM_REMOTE_URL must be set when LLM_MODE is 'remote'")
|
||||
raise LLMConfigurationError(
|
||||
"LLM_REMOTE_URL must be set when LLM_MODE is 'remote'"
|
||||
)
|
||||
return RemoteAdapter(
|
||||
url=settings.llm_remote_url,
|
||||
token=settings.llm_remote_token or None,
|
||||
)
|
||||
|
||||
return LocalAdapter()
|
||||
|
||||
|
||||
# Type alias for clean dependency injection in endpoints
|
||||
AdapterDependency = Annotated[LLMAdapter, Depends(get_adapter)]
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
"""Custom exceptions for LLM adapters."""
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
class LLMError(Exception):
|
||||
"""Base exception for all LLM-related errors."""
|
||||
|
||||
def __init__(self, message: str, status_code: int = 500):
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class LLMAuthenticationError(LLMError):
|
||||
"""Raised when API authentication fails."""
|
||||
|
||||
def __init__(self, message: str = "LLM authentication failed"):
|
||||
super().__init__(message, status_code=401)
|
||||
|
||||
|
||||
class LLMRateLimitError(LLMError):
|
||||
"""Raised when rate limit is exceeded."""
|
||||
|
||||
def __init__(self, message: str = "LLM rate limit exceeded"):
|
||||
super().__init__(message, status_code=429)
|
||||
|
||||
|
||||
class LLMConnectionError(LLMError):
|
||||
"""Raised when connection to LLM service fails."""
|
||||
|
||||
def __init__(self, message: str = "Could not connect to LLM service"):
|
||||
super().__init__(message, status_code=503)
|
||||
|
||||
|
||||
class LLMConfigurationError(LLMError):
|
||||
"""Raised when LLM configuration is invalid."""
|
||||
|
||||
def __init__(self, message: str = "Invalid LLM configuration"):
|
||||
super().__init__(message, status_code=500)
|
||||
|
||||
|
||||
class LLMResponseError(LLMError):
|
||||
"""Raised when LLM returns an invalid or unexpected response."""
|
||||
|
||||
def __init__(self, message: str = "Invalid response from LLM"):
|
||||
super().__init__(message, status_code=502)
|
||||
|
||||
|
||||
def llm_exception_to_http(exc: LLMError) -> HTTPException:
|
||||
"""Convert an LLMError to a FastAPI HTTPException.
|
||||
|
||||
Args:
|
||||
exc: The LLMError to convert
|
||||
|
||||
Returns:
|
||||
An HTTPException with appropriate status code and detail
|
||||
"""
|
||||
return HTTPException(status_code=exc.status_code, detail=exc.message)
|
||||
+20
-5
@@ -4,7 +4,7 @@ import uuid
|
||||
from fastapi import FastAPI, HTTPException
|
||||
|
||||
from app.config import settings, MAX_MESSAGE_LENGTH
|
||||
from app.llm.adapter import get_adapter
|
||||
from app.llm import AdapterDependency, LLMError, llm_exception_to_http
|
||||
from app.schemas import ChatRequest, ChatResponse, HealthResponse
|
||||
|
||||
# Configure logging
|
||||
@@ -29,7 +29,7 @@ async def health_check() -> HealthResponse:
|
||||
|
||||
|
||||
@app.post("/chat", response_model=ChatResponse)
|
||||
async def chat(request: ChatRequest) -> ChatResponse:
|
||||
async def chat(request: ChatRequest, adapter: AdapterDependency) -> ChatResponse:
|
||||
"""Process a chat message through the LLM adapter.
|
||||
|
||||
- Validates message length
|
||||
@@ -57,9 +57,19 @@ async def chat(request: ChatRequest) -> ChatResponse:
|
||||
},
|
||||
)
|
||||
|
||||
# Get adapter and generate response
|
||||
adapter = get_adapter()
|
||||
response_text = await adapter.generate(conversation_id, request.message)
|
||||
# Generate response with exception handling
|
||||
try:
|
||||
response_text = await adapter.generate(conversation_id, request.message)
|
||||
except LLMError as e:
|
||||
logger.error(
|
||||
"LLM generation failed",
|
||||
extra={
|
||||
"conversation_id": conversation_id,
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": e.message,
|
||||
},
|
||||
)
|
||||
raise llm_exception_to_http(e)
|
||||
|
||||
# Log response metadata
|
||||
logger.info(
|
||||
@@ -77,3 +87,8 @@ async def chat(request: ChatRequest) -> ChatResponse:
|
||||
mode=settings.llm_mode,
|
||||
sources=[],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run("app.main:app", host="127.0.0.1", port=8000, reload=True)
|
||||
|
||||
+1
-1
@@ -17,7 +17,7 @@ class ChatResponse(BaseModel):
|
||||
|
||||
conversation_id: str = Field(..., description="Conversation ID (generated if not provided)")
|
||||
response: str = Field(..., description="The LLM's response")
|
||||
mode: Literal["local", "remote"] = Field(..., description="Which adapter was used")
|
||||
mode: Literal["local", "remote", "openai"] = Field(..., description="Which adapter was used")
|
||||
sources: list = Field(default_factory=list, description="Source references (empty for now)")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user