# tyndale-ai-service/app/embeddings/client.py
"""OpenAI embedding client wrapper."""
import logging
from functools import lru_cache
from typing import Annotated
from fastapi import Depends
from openai import AsyncOpenAI, AuthenticationError, RateLimitError, APIConnectionError, APIError
from app.config import settings
logger = logging.getLogger(__name__)
class EmbeddingError(Exception):
    """Raised when an embedding operation fails.

    Carries an HTTP-style status code so API layers can translate the
    failure directly into a response.
    """

    def __init__(self, message: str, status_code: int = 500):
        super().__init__(message)
        # Keep the raw message and status around for handlers that build
        # structured error responses.
        self.message = message
        self.status_code = status_code
class EmbeddingClient:
    """Async wrapper for the OpenAI embeddings API."""

    def __init__(self, api_key: str, model: str = "text-embedding-3-small"):
        """Initialize the embedding client.

        Args:
            api_key: OpenAI API key.
            model: Embedding model identifier.
        """
        self.client = AsyncOpenAI(api_key=api_key)
        self.model = model
        # NOTE(review): hard-coded for text-embedding-3-small; if `model` is
        # ever something else, this value is wrong — confirm or derive it.
        self.dimensions = 1536

    @staticmethod
    def _to_embedding_error(exc: "APIError") -> "EmbeddingError":
        """Translate an OpenAI SDK exception into an EmbeddingError.

        Checks the most specific subclasses first: in the openai SDK,
        AuthenticationError, RateLimitError and APIConnectionError all
        derive from APIError, so the generic case must come last.
        """
        if isinstance(exc, AuthenticationError):
            return EmbeddingError(f"OpenAI authentication failed: {exc.message}", 401)
        if isinstance(exc, RateLimitError):
            return EmbeddingError(f"OpenAI rate limit exceeded: {exc.message}", 429)
        if isinstance(exc, APIConnectionError):
            return EmbeddingError(f"Could not connect to OpenAI: {exc}", 503)
        # Bare APIError has no status_code (only APIStatusError subclasses
        # do), so use getattr to avoid an AttributeError while translating.
        status = getattr(exc, "status_code", None) or 500
        return EmbeddingError(f"OpenAI API error: {exc.message}", status)

    async def embed(self, text: str) -> list[float]:
        """Generate an embedding for a single text.

        Args:
            text: Text to embed.

        Returns:
            Embedding vector (1536 dimensions for text-embedding-3-small).

        Raises:
            EmbeddingError: If embedding generation fails.
        """
        try:
            response = await self.client.embeddings.create(
                model=self.model,
                input=text,
            )
        except APIError as e:
            # Chain the SDK exception so the original traceback is preserved.
            raise self._to_embedding_error(e) from e
        return response.data[0].embedding

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for multiple texts in one API call.

        Args:
            texts: List of texts to embed.

        Returns:
            List of embedding vectors, in the same order as `texts`.
            Empty input returns an empty list without calling the API.

        Raises:
            EmbeddingError: If embedding generation fails.
        """
        if not texts:
            return []
        try:
            response = await self.client.embeddings.create(
                model=self.model,
                input=texts,
            )
        except APIError as e:
            raise self._to_embedding_error(e) from e
        # Sort by index so results line up with the input order.
        ordered = sorted(response.data, key=lambda item: item.index)
        return [item.embedding for item in ordered]
@lru_cache()
def get_embedding_client() -> EmbeddingClient:
    """Return the process-wide embedding client.

    The zero-argument lru_cache memoizes the single call, so the underlying
    OpenAI client is constructed once and shared by every request.
    """
    client = EmbeddingClient(
        api_key=settings.openai_api_key,
        model=settings.embedding_model,
    )
    return client
# FastAPI dependency alias: injects the cached EmbeddingClient into routes.
EmbeddingClientDependency = Annotated[EmbeddingClient, Depends(get_embedding_client)]