"""OpenAI embedding client wrapper.""" import logging from functools import lru_cache from typing import Annotated from fastapi import Depends from openai import AsyncOpenAI, AuthenticationError, RateLimitError, APIConnectionError, APIError from app.config import settings logger = logging.getLogger(__name__) class EmbeddingError(Exception): """Base exception for embedding operations.""" def __init__(self, message: str, status_code: int = 500): self.message = message self.status_code = status_code super().__init__(message) class EmbeddingClient: """Async wrapper for OpenAI embeddings API.""" def __init__(self, api_key: str, model: str = "text-embedding-3-small"): """Initialize the embedding client. Args: api_key: OpenAI API key model: Embedding model identifier """ self.client = AsyncOpenAI(api_key=api_key) self.model = model self.dimensions = 1536 # text-embedding-3-small dimension async def embed(self, text: str) -> list[float]: """Generate embedding for a single text. Args: text: Text to embed Returns: Embedding vector (1536 dimensions) Raises: EmbeddingError: If embedding generation fails """ try: response = await self.client.embeddings.create( model=self.model, input=text, ) return response.data[0].embedding except AuthenticationError as e: raise EmbeddingError(f"OpenAI authentication failed: {e.message}", 401) except RateLimitError as e: raise EmbeddingError(f"OpenAI rate limit exceeded: {e.message}", 429) except APIConnectionError as e: raise EmbeddingError(f"Could not connect to OpenAI: {str(e)}", 503) except APIError as e: raise EmbeddingError(f"OpenAI API error: {e.message}", e.status_code or 500) async def embed_batch(self, texts: list[str]) -> list[list[float]]: """Generate embeddings for multiple texts. Args: texts: List of texts to embed Returns: List of embedding vectors Raises: EmbeddingError: If embedding generation fails """ if not texts: return [] try: response = await self.client.embeddings.create( model=self.model, input=texts, ) # Sort by index to ensure correct ordering sorted_embeddings = sorted(response.data, key=lambda x: x.index) return [item.embedding for item in sorted_embeddings] except AuthenticationError as e: raise EmbeddingError(f"OpenAI authentication failed: {e.message}", 401) except RateLimitError as e: raise EmbeddingError(f"OpenAI rate limit exceeded: {e.message}", 429) except APIConnectionError as e: raise EmbeddingError(f"Could not connect to OpenAI: {str(e)}", 503) except APIError as e: raise EmbeddingError(f"OpenAI API error: {e.message}", e.status_code or 500) @lru_cache() def get_embedding_client() -> EmbeddingClient: """Get cached embedding client instance.""" return EmbeddingClient( api_key=settings.openai_api_key, model=settings.embedding_model, ) EmbeddingClientDependency = Annotated[EmbeddingClient, Depends(get_embedding_client)]