110 lines
3.4 KiB
Python
110 lines
3.4 KiB
Python
"""OpenAI embedding client wrapper."""
|
|
|
|
import logging
|
|
from functools import lru_cache
|
|
from typing import Annotated
|
|
|
|
from fastapi import Depends
|
|
from openai import AsyncOpenAI, AuthenticationError, RateLimitError, APIConnectionError, APIError
|
|
|
|
from app.config import settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class EmbeddingError(Exception):
    """Base exception for embedding operations.

    Carries an HTTP-style status code alongside the message so callers
    (e.g. FastAPI exception handlers) can map failures to responses.
    """

    def __init__(self, message: str, status_code: int = 500):
        super().__init__(message)
        # Keep both pieces accessible as attributes for handlers upstream.
        self.message = message
        self.status_code = status_code
|
|
|
|
|
|
class EmbeddingClient:
    """Async wrapper for OpenAI embeddings API.

    Translates OpenAI SDK exceptions into :class:`EmbeddingError` with an
    appropriate HTTP status code, so the rest of the app never needs to
    import openai exception types.
    """

    def __init__(self, api_key: str, model: str = "text-embedding-3-small"):
        """Initialize the embedding client.

        Args:
            api_key: OpenAI API key
            model: Embedding model identifier
        """
        self.client = AsyncOpenAI(api_key=api_key)
        self.model = model
        self.dimensions = 1536  # text-embedding-3-small dimension

    async def _create(self, payload: str | list[str]):
        """Call the embeddings endpoint, mapping SDK errors to EmbeddingError.

        Shared by :meth:`embed` and :meth:`embed_batch` so the error
        translation lives in exactly one place.

        Args:
            payload: A single text or a list of texts to embed.

        Returns:
            The raw OpenAI embeddings response.

        Raises:
            EmbeddingError: On any OpenAI SDK failure, with a status code
                reflecting the failure class (401/429/503, or the upstream
                status when available, else 500).
        """
        try:
            return await self.client.embeddings.create(
                model=self.model,
                input=payload,
            )
        # Specific subclasses first; `from e` preserves the original traceback.
        except AuthenticationError as e:
            raise EmbeddingError(f"OpenAI authentication failed: {e.message}", 401) from e
        except RateLimitError as e:
            raise EmbeddingError(f"OpenAI rate limit exceeded: {e.message}", 429) from e
        except APIConnectionError as e:
            raise EmbeddingError(f"Could not connect to OpenAI: {str(e)}", 503) from e
        except APIError as e:
            # `status_code` exists only on APIStatusError subclasses, not the
            # APIError base — use getattr so a non-status error can't crash here.
            status = getattr(e, "status_code", None) or 500
            raise EmbeddingError(f"OpenAI API error: {e.message}", status) from e

    async def embed(self, text: str) -> list[float]:
        """Generate embedding for a single text.

        Args:
            text: Text to embed

        Returns:
            Embedding vector (1536 dimensions)

        Raises:
            EmbeddingError: If embedding generation fails
        """
        response = await self._create(text)
        return response.data[0].embedding

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for multiple texts.

        Args:
            texts: List of texts to embed

        Returns:
            List of embedding vectors, in the same order as `texts`

        Raises:
            EmbeddingError: If embedding generation fails
        """
        if not texts:
            return []

        response = await self._create(texts)
        # Sort by index to ensure correct ordering
        sorted_embeddings = sorted(response.data, key=lambda x: x.index)
        return [item.embedding for item in sorted_embeddings]
|
|
|
|
|
|
@lru_cache()
def get_embedding_client() -> EmbeddingClient:
    """Return the process-wide EmbeddingClient singleton.

    Built once from application settings; lru_cache makes subsequent calls
    return the same instance, so FastAPI dependency resolution is cheap.
    """
    client = EmbeddingClient(
        api_key=settings.openai_api_key,
        model=settings.embedding_model,
    )
    return client
|
|
|
|
|
|
# FastAPI dependency alias: annotate a route parameter with this type to have
# the cached EmbeddingClient injected via get_embedding_client().
EmbeddingClientDependency = Annotated[EmbeddingClient, Depends(get_embedding_client)]
|