"""LLM adapter implementations with FastAPI dependency injection support."""

import asyncio
from abc import ABC, abstractmethod
from functools import lru_cache
from collections.abc import AsyncIterator
from typing import Annotated

import httpx
from asksageclient import AskSageClient
from fastapi import Depends
from openai import (
    AsyncOpenAI,
    AuthenticationError,
    RateLimitError,
    APIConnectionError,
    APIError,
)

from app.config import Settings, settings
from app.llm.exceptions import (
    LLMError,
    LLMAuthenticationError,
    LLMRateLimitError,
    LLMConnectionError,
    LLMConfigurationError,
    LLMResponseError,
)


class LLMAdapter(ABC):
    """Abstract base class for LLM adapters."""

    @abstractmethod
    async def generate(self, conversation_id: str, message: str) -> str:
        """Generate a response for the given message.

        Args:
            conversation_id: The conversation identifier
            message: The user's message

        Returns:
            The generated response string

        Raises:
            LLMError: If generation fails for any reason
        """
        pass

    async def generate_stream(
        self, conversation_id: str, message: str
    ) -> AsyncIterator[str]:
        """Stream a response for the given message.

        Default implementation yields the full response as a single chunk.
        Subclasses can override this to provide true streaming.

        Args:
            conversation_id: The conversation identifier
            message: The user's message

        Yields:
            Response content chunks

        Raises:
            LLMError: If generation fails for any reason
        """
        response = await self.generate(conversation_id, message)
        yield response
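

# Every adapter is consumed the same way; a minimal sketch (illustrative only, inside an
# async context, using the LocalAdapter stub defined below; the conversation id and
# message are placeholders):
#
#     adapter: LLMAdapter = LocalAdapter()
#     reply = await adapter.generate("conv-1", "Hello")
#     async for chunk in adapter.generate_stream("conv-1", "Hello"):
#         print(chunk, end="", flush=True)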


class LocalAdapter(LLMAdapter):
    """Local stub adapter for development and testing."""

    async def generate(self, conversation_id: str, message: str) -> str:
        """Return a stub response echoing the user message.

        This is a placeholder that will be replaced with a real local model.
        """
        return (
            f"[LOCAL STUB MODE] Acknowledged your message. "
            f"You said: \"{message[:100]}{'...' if len(message) > 100 else ''}\". "
            f"This is a stub response - local model not yet implemented."
        )


class RemoteAdapter(LLMAdapter):
    """Remote adapter that calls an external LLM service via HTTP."""

    def __init__(self, url: str, token: str | None = None, timeout: float = 30.0):
        """Initialize the remote adapter.

        Args:
            url: The remote LLM service URL
            token: Optional bearer token for authentication
            timeout: Request timeout in seconds
        """
        self.url = url
        self.token = token
        self.timeout = timeout

    async def generate(self, conversation_id: str, message: str) -> str:
        """Call the remote LLM service to generate a response.

        Raises:
            LLMConnectionError: If connection fails or times out
            LLMAuthenticationError: If authentication fails
            LLMRateLimitError: If rate limit is exceeded
            LLMResponseError: If response is invalid
            LLMError: For other HTTP errors
        """
        headers = {"Content-Type": "application/json"}
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"

        payload = {
            "conversation_id": conversation_id,
            "message": message,
        }

        try:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.post(self.url, json=payload, headers=headers)

                if response.status_code == 401:
                    raise LLMAuthenticationError("Remote LLM authentication failed")
                if response.status_code == 429:
                    raise LLMRateLimitError("Remote LLM rate limit exceeded")
                if response.status_code != 200:
                    raise LLMError(
                        f"Remote LLM returned status {response.status_code}: "
                        f"{response.text[:200] if response.text else 'No response body'}",
                        status_code=response.status_code if 400 <= response.status_code < 600 else 502,
                    )

                try:
                    data = response.json()
                except ValueError:
                    raise LLMResponseError("Remote LLM returned invalid JSON response")

                if "response" not in data:
                    raise LLMResponseError("Remote LLM response missing 'response' field")

                return data["response"]

        except httpx.TimeoutException:
            raise LLMConnectionError(
                f"Remote LLM request timed out after {self.timeout} seconds"
            )
        except httpx.ConnectError:
            raise LLMConnectionError(f"Could not connect to remote LLM at {self.url}")
        except httpx.RequestError as e:
            raise LLMConnectionError(f"Remote LLM request failed: {str(e)}")
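

# The wire contract this adapter expects from the remote service, inferred from the
# request/response handling above (anything beyond these fields is ignored):
#
#     POST <url>                          -> {"conversation_id": "...", "message": "..."}
#     200 OK                              <- {"response": "<generated text>"}
#     401 / 429 / any other non-200 code  <- mapped to the LLM* exceptions raised above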


class OpenAIAdapter(LLMAdapter):
    """OpenAI API adapter using the official SDK with native async support."""

    def __init__(self, api_key: str, model: str = "gpt-4o-mini"):
        """Initialize the OpenAI adapter.

        Args:
            api_key: OpenAI API key
            model: Model identifier (e.g., "gpt-4o-mini", "gpt-4o")
        """
        self.client = AsyncOpenAI(api_key=api_key)
        self.model = model

    async def generate(self, conversation_id: str, message: str) -> str:
        """Generate a response using the OpenAI API.

        Args:
            conversation_id: The conversation identifier (for future use with context)
            message: The user's message

        Returns:
            The generated response string

        Raises:
            LLMAuthenticationError: If API key is invalid
            LLMRateLimitError: If rate limit is exceeded
            LLMConnectionError: If connection fails
            LLMResponseError: If response content is empty
            LLMError: For other API errors
        """
        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": message}],
            )

            content = response.choices[0].message.content
            if content is None:
                raise LLMResponseError("OpenAI returned empty response content")
            return content

        except AuthenticationError as e:
            raise LLMAuthenticationError(f"OpenAI authentication failed: {e.message}")
        except RateLimitError as e:
            raise LLMRateLimitError(f"OpenAI rate limit exceeded: {e.message}")
        except APIConnectionError as e:
            raise LLMConnectionError(f"Could not connect to OpenAI: {str(e)}")
        except APIError as e:
            raise LLMError(
                f"OpenAI API error: {e.message}", status_code=e.status_code or 500
            )

    async def generate_stream(
        self, conversation_id: str, message: str
    ) -> AsyncIterator[str]:
        """Stream a response using the OpenAI API.

        Args:
            conversation_id: The conversation identifier (for future use with context)
            message: The user's message

        Yields:
            Response content chunks

        Raises:
            LLMAuthenticationError: If API key is invalid
            LLMRateLimitError: If rate limit is exceeded
            LLMConnectionError: If connection fails
            LLMError: For other API errors
        """
        try:
            stream = await self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": message}],
                stream=True,
            )

            async for chunk in stream:
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content

        except AuthenticationError as e:
            raise LLMAuthenticationError(f"OpenAI authentication failed: {e.message}")
        except RateLimitError as e:
            raise LLMRateLimitError(f"OpenAI rate limit exceeded: {e.message}")
        except APIConnectionError as e:
            raise LLMConnectionError(f"Could not connect to OpenAI: {str(e)}")
        except APIError as e:
            raise LLMError(
                f"OpenAI API error: {e.message}", status_code=e.status_code or 500
            )
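

# Ad-hoc use outside FastAPI; a quick sketch (illustrative only, assumes a valid key is
# exported as OPENAI_API_KEY in the environment):
#
#     import os
#
#     adapter = OpenAIAdapter(api_key=os.environ["OPENAI_API_KEY"], model="gpt-4o-mini")
#     print(asyncio.run(adapter.generate("conv-1", "Say hello")))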


class AskSageAdapter(LLMAdapter):
    """AskSage API adapter using the official asksageclient SDK."""

    def __init__(self, email: str, api_key: str, model: str = "gpt-4o"):
        """Initialize the AskSage adapter.

        Args:
            email: AskSage account email
            api_key: AskSage API key
            model: Model identifier (e.g., "gpt-4o", "claude-3-opus")
        """
        self.client = AskSageClient(email=email, api_key=api_key)
        self.model = model

    async def generate(self, conversation_id: str, message: str) -> str:
        """Generate a response using the AskSage API.

        Args:
            conversation_id: The conversation identifier (for future use)
            message: The user's message

        Returns:
            The generated response string

        Raises:
            LLMAuthenticationError: If credentials are invalid
            LLMConnectionError: If connection fails
            LLMResponseError: If response content is empty or invalid
            LLMError: For other API errors
        """
        try:
            # AskSageClient is synchronous; run it in a thread pool to avoid
            # blocking the event loop.
            response = await asyncio.to_thread(
                self.client.query,
                message=message,
                model=self.model,
            )

            if not isinstance(response, dict):
                raise LLMResponseError("AskSage returned invalid response format")

            content = response.get("response")
            if content is None:
                raise LLMResponseError("AskSage returned empty response content")
            return content

        except LLMError:
            raise
        except Exception as e:
            error_msg = str(e).lower()
            if "auth" in error_msg or "401" in error_msg or "unauthorized" in error_msg:
                raise LLMAuthenticationError(f"AskSage authentication failed: {e}")
            if "connect" in error_msg or "timeout" in error_msg:
                raise LLMConnectionError(f"Could not connect to AskSage: {e}")
            raise LLMError(f"AskSage API error: {e}")


# --- Dependency Injection Support ---


@lru_cache()
def get_settings() -> Settings:
    """Get cached settings instance for dependency injection."""
    return settings


def get_adapter(settings: Annotated[Settings, Depends(get_settings)]) -> LLMAdapter:
    """Factory function to create the appropriate adapter based on configuration.

    This function is designed for use with FastAPI's Depends() system.

    Args:
        settings: Application settings (injected by FastAPI)

    Returns:
        An LLMAdapter instance based on the LLM_MODE setting

    Raises:
        LLMConfigurationError: If configuration is invalid for the selected mode
    """
    if settings.llm_mode == "openai":
        if not settings.openai_api_key:
            raise LLMConfigurationError(
                "OPENAI_API_KEY must be set when LLM_MODE is 'openai'"
            )
        return OpenAIAdapter(
            api_key=settings.openai_api_key,
            model=settings.openai_model,
        )

    if settings.llm_mode == "remote":
        if not settings.llm_remote_url:
            raise LLMConfigurationError(
                "LLM_REMOTE_URL must be set when LLM_MODE is 'remote'"
            )
        return RemoteAdapter(
            url=settings.llm_remote_url,
            token=settings.llm_remote_token or None,
        )

    if settings.llm_mode == "asksage":
        if not settings.asksage_email or not settings.asksage_api_key:
            raise LLMConfigurationError(
                "ASKSAGE_EMAIL and ASKSAGE_API_KEY must be set when LLM_MODE is 'asksage'"
            )
        return AskSageAdapter(
            email=settings.asksage_email,
            api_key=settings.asksage_api_key,
            model=settings.asksage_model,
        )

    return LocalAdapter()
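

# Because the adapter is resolved through get_adapter()/get_settings(), tests can swap it
# out with FastAPI's dependency_overrides. A sketch (illustrative only; it assumes the
# FastAPI app object lives in app.main, which is outside this module):
#
#     from app.main import app
#
#     app.dependency_overrides[get_adapter] = lambda: LocalAdapter()
#     # ... exercise endpoints against the stub adapter ...
#     app.dependency_overrides.clear()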


# Type alias for clean dependency injection in endpoints
AdapterDependency = Annotated[LLMAdapter, Depends(get_adapter)]
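
# A minimal endpoint sketch using the alias (illustrative only; the router, path, and
# request model are hypothetical and not part of this module):
#
#     from fastapi import APIRouter
#     from pydantic import BaseModel
#
#     router = APIRouter()
#
#     class ChatRequest(BaseModel):
#         conversation_id: str
#         message: str
#
#     @router.post("/chat")
#     async def chat(body: ChatRequest, adapter: AdapterDependency) -> dict[str, str]:
#         reply = await adapter.generate(body.conversation_id, body.message)
#         return {"response": reply}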