tyndale-ai-service/app/llm/adapter.py

"""LLM adapter implementations with FastAPI dependency injection support."""
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Annotated
import httpx
from fastapi import Depends
from openai import (
AsyncOpenAI,
AuthenticationError,
RateLimitError,
APIConnectionError,
APIError,
)
from app.config import Settings, settings
from app.llm.exceptions import (
LLMError,
LLMAuthenticationError,
LLMRateLimitError,
LLMConnectionError,
LLMConfigurationError,
LLMResponseError,
)
class LLMAdapter(ABC):
"""Abstract base class for LLM adapters."""
@abstractmethod
async def generate(self, conversation_id: str, message: str) -> str:
"""Generate a response for the given message.
Args:
conversation_id: The conversation identifier
message: The user's message
Returns:
The generated response string
Raises:
LLMError: If generation fails for any reason
"""
pass
class LocalAdapter(LLMAdapter):
"""Local stub adapter for development and testing."""
async def generate(self, conversation_id: str, message: str) -> str:
"""Return a stub response echoing the user message.
This is a placeholder that will be replaced with a real local model.
"""
return (
f"[LOCAL STUB MODE] Acknowledged your message. "
f"You said: \"{message[:100]}{'...' if len(message) > 100 else ''}\". "
f"This is a stub response - local model not yet implemented."
)
class RemoteAdapter(LLMAdapter):
"""Remote adapter that calls an external LLM service via HTTP."""
def __init__(self, url: str, token: str | None = None, timeout: float = 30.0):
"""Initialize the remote adapter.
Args:
url: The remote LLM service URL
token: Optional bearer token for authentication
timeout: Request timeout in seconds
"""
self.url = url
self.token = token
self.timeout = timeout
async def generate(self, conversation_id: str, message: str) -> str:
"""Call the remote LLM service to generate a response.
Raises:
LLMConnectionError: If connection fails or times out
LLMAuthenticationError: If authentication fails
LLMRateLimitError: If rate limit is exceeded
LLMResponseError: If response is invalid
LLMError: For other HTTP errors
"""
headers = {"Content-Type": "application/json"}
if self.token:
headers["Authorization"] = f"Bearer {self.token}"
payload = {
"conversation_id": conversation_id,
"message": message,
}
try:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(self.url, json=payload, headers=headers)
if response.status_code == 401:
raise LLMAuthenticationError("Remote LLM authentication failed")
if response.status_code == 429:
raise LLMRateLimitError("Remote LLM rate limit exceeded")
if response.status_code != 200:
raise LLMError(
f"Remote LLM returned status {response.status_code}: "
f"{response.text[:200] if response.text else 'No response body'}",
status_code=response.status_code if 400 <= response.status_code < 600 else 502,
)
try:
data = response.json()
except ValueError:
raise LLMResponseError("Remote LLM returned invalid JSON response")
if "response" not in data:
raise LLMResponseError("Remote LLM response missing 'response' field")
return data["response"]
except httpx.TimeoutException:
raise LLMConnectionError(
f"Remote LLM request timed out after {self.timeout} seconds"
)
except httpx.ConnectError:
raise LLMConnectionError(f"Could not connect to remote LLM at {self.url}")
except httpx.RequestError as e:
raise LLMConnectionError(f"Remote LLM request failed: {str(e)}")


class OpenAIAdapter(LLMAdapter):
    """OpenAI API adapter using the official SDK with native async support."""

    def __init__(self, api_key: str, model: str = "gpt-4o-mini"):
        """Initialize the OpenAI adapter.

        Args:
            api_key: OpenAI API key
            model: Model identifier (e.g., "gpt-4o-mini", "gpt-4o")
        """
        self.client = AsyncOpenAI(api_key=api_key)
        self.model = model

    async def generate(self, conversation_id: str, message: str) -> str:
        """Generate a response using the OpenAI API.

        Args:
            conversation_id: The conversation identifier (for future use with context)
            message: The user's message

        Returns:
            The generated response string

        Raises:
            LLMAuthenticationError: If API key is invalid
            LLMRateLimitError: If rate limit is exceeded
            LLMConnectionError: If connection fails
            LLMResponseError: If response content is empty
            LLMError: For other API errors
        """
        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": message}],
            )
            content = response.choices[0].message.content
            if content is None:
                raise LLMResponseError("OpenAI returned empty response content")
            return content
        except AuthenticationError as e:
            raise LLMAuthenticationError(f"OpenAI authentication failed: {e.message}")
        except RateLimitError as e:
            raise LLMRateLimitError(f"OpenAI rate limit exceeded: {e.message}")
        except APIConnectionError as e:
            raise LLMConnectionError(f"Could not connect to OpenAI: {str(e)}")
        except APIError as e:
            # The base APIError does not define status_code, so fall back to 500
            # when it is absent.
            raise LLMError(
                f"OpenAI API error: {e.message}",
                status_code=getattr(e, "status_code", None) or 500,
            )
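
# Direct-use sketch (illustrative values; in the service the adapter is normally
# obtained through get_adapter() below rather than constructed by hand):
#
#   adapter = OpenAIAdapter(api_key="sk-...", model="gpt-4o-mini")
#   reply = await adapter.generate(conversation_id="demo", message="Hello")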


# --- Dependency Injection Support ---

@lru_cache()
def get_settings() -> Settings:
    """Get cached settings instance for dependency injection."""
    return settings


def get_adapter(settings: Annotated[Settings, Depends(get_settings)]) -> LLMAdapter:
    """Factory function to create the appropriate adapter based on configuration.

    This function is designed for use with FastAPI's Depends() system.

    Args:
        settings: Application settings (injected by FastAPI)

    Returns:
        An LLMAdapter instance based on the LLM_MODE setting

    Raises:
        LLMConfigurationError: If configuration is invalid for the selected mode
    """
    if settings.llm_mode == "openai":
        if not settings.openai_api_key:
            raise LLMConfigurationError(
                "OPENAI_API_KEY must be set when LLM_MODE is 'openai'"
            )
        return OpenAIAdapter(
            api_key=settings.openai_api_key,
            model=settings.openai_model,
        )
    if settings.llm_mode == "remote":
        if not settings.llm_remote_url:
            raise LLMConfigurationError(
                "LLM_REMOTE_URL must be set when LLM_MODE is 'remote'"
            )
        return RemoteAdapter(
            url=settings.llm_remote_url,
            token=settings.llm_remote_token or None,
        )
    return LocalAdapter()


# Type alias for clean dependency injection in endpoints
AdapterDependency = Annotated[LLMAdapter, Depends(get_adapter)]
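
# Endpoint usage sketch (the route path and the ChatRequest/ChatResponse models
# below are hypothetical, not defined in this module):
#
#   @router.post("/chat")
#   async def chat(body: ChatRequest, adapter: AdapterDependency) -> ChatResponse:
#       reply = await adapter.generate(body.conversation_id, body.message)
#       return ChatResponse(response=reply)
#
# In tests, the real adapter can be swapped out via FastAPI's dependency overrides,
# e.g. app.dependency_overrides[get_adapter] = lambda: LocalAdapter(), where `app`
# is the FastAPI instance (not created in this module).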