99 lines
3.6 KiB
Python
99 lines
3.6 KiB
Python
"""Lightweight intent classification using gpt-4o-mini."""
|
|
|
|
import logging
|
|
from functools import lru_cache
|
|
from typing import Annotated, Literal
|
|
|
|
from fastapi import Depends
|
|
from openai import AsyncOpenAI
|
|
|
|
from app.config import settings
|
|
from app.memory.conversation import Message
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
# Closed set of labels a message can be classified as; used as the return
# type of ``IntentClassifier.classify``.
Intent = Literal["codebase", "general", "clarification"]
|
|
|
|
# System prompt sent with every classification request. It describes the three
# labels, tells the model to prefer "codebase" when unsure, and asks for the
# bare label as the only output.
INTENT_PROMPT = """You are classifying questions for a Tyndale trading system documentation assistant.

Classify this user message into one category:
- "codebase": ANY question about trading, strategies, exchanges, orders, positions, risk, execution, hedging, market making, P&L, or how the system works. Also includes questions with "our", "the system", "this", or references to specific functionality.
- "general": ONLY greetings or completely off-topic questions ("How are you?", "What's the weather?", "Hello")
- "clarification": Follow-ups that reference previous answers ("Tell me more", "What did you mean?", "Can you explain that?")

DEFAULT TO "codebase" if uncertain. This is a trading system assistant - assume trading questions are about the codebase.

Respond with ONLY the category name, nothing else."""
|
|
|
|
|
|
class IntentClassifier:
    """Classify a chat message as "codebase", "general", or "clarification".

    The classifier sends ``INTENT_PROMPT`` plus the user's message (and a short
    excerpt of recent history, when available) to an OpenAI chat model and
    interprets the one-word reply. Any failure or unrecognised reply falls
    back to "codebase".
    """

    # Labels accepted from the model. Typed as ``Intent`` so a matched label
    # can be returned directly without widening to ``str``.
    _VALID_INTENTS: tuple[Intent, ...] = ("codebase", "general", "clarification")
    # Number of trailing history messages included as context (two exchanges).
    _HISTORY_WINDOW = 4
    # Maximum characters of each history message included in the context.
    _SNIPPET_CHARS = 100
    # Characters the model sometimes wraps around the label (the prompt itself
    # shows the labels in double quotes).
    _REPLY_WRAPPERS = "\"'`.,:;!"

    def __init__(self, api_key: str, model: str = "gpt-4o-mini"):
        """Initialize the classifier.

        Args:
            api_key: OpenAI API key.
            model: Chat model used for classification. Defaults to
                "gpt-4o-mini".
        """
        self.client = AsyncOpenAI(api_key=api_key)
        self.model = model

    @classmethod
    def _build_context(cls, history: list[Message] | None) -> str:
        """Render the most recent history messages as a text preamble.

        Returns an empty string when there are fewer than two messages.
        """
        if history is None or len(history) < 2:
            return ""

        lines = ["Recent conversation:"]
        for msg in history[-cls._HISTORY_WINDOW:]:
            speaker = "User" if msg.role == "user" else "Assistant"
            text = msg.content
            if len(text) > cls._SNIPPET_CHARS:
                text = f"{text[:cls._SNIPPET_CHARS]}..."
            lines.append(f"{speaker}: {text}")
        # One newline ends the last message, the second separates the
        # preamble from the "Current message" line.
        return "\n".join(lines) + "\n\n"

    @classmethod
    def _parse_intent(cls, reply: str) -> Intent | None:
        """Map a raw model reply to a known label, or None if unrecognised.

        Surrounding whitespace, quotes and trailing punctuation are ignored so
        replies such as ``"general"`` or ``General.`` are still accepted.
        """
        cleaned = reply.strip().strip(cls._REPLY_WRAPPERS).strip().lower()
        for candidate in cls._VALID_INTENTS:
            if cleaned == candidate:
                return candidate
        return None

    async def classify(
        self,
        message: str,
        history: list[Message] | None = None,
    ) -> Intent:
        """Classify user message intent.

        Args:
            message: User's message.
            history: Optional conversation history for context. Only the last
                four messages are used, and only when at least two exist.

        Returns:
            Classified intent: "codebase", "general", or "clarification".
            Falls back to "codebase" when the request fails or the reply is
            not one of the known labels. This method does not raise.
        """
        context = self._build_context(history)

        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": INTENT_PROMPT},
                    {"role": "user", "content": f"{context}Current message: {message}"},
                ],
                max_tokens=10,
                temperature=0,
            )
            # ``content`` is optional in the SDK response type; treat a
            # missing value as an empty (hence unrecognised) reply.
            reply = response.choices[0].message.content or ""
        except Exception as e:
            # Deliberate best-effort: classification must never break the
            # request, so every failure degrades to the default label.
            logger.warning("Intent classification failed: %s, defaulting to codebase", e)
            return "codebase"

        intent = self._parse_intent(reply)
        if intent is not None:
            logger.info("Intent classified: '%s...' -> %s", message[:50], intent)
            return intent

        # Default to codebase for ambiguous cases (safer for RAG)
        logger.warning(
            "Unexpected intent response: %s, defaulting to codebase",
            reply.strip().lower(),
        )
        return "codebase"
|
|
|
|
|
|
@lru_cache
def get_intent_classifier() -> IntentClassifier:
    """Return the shared classifier, building it on the first call.

    The instance is created from ``settings.openai_api_key`` and memoised by
    ``lru_cache``, so later calls return the same object.
    """
    openai_key = settings.openai_api_key
    classifier = IntentClassifier(api_key=openai_key)
    return classifier
|
|
|
|
|
|
# FastAPI dependency alias: a parameter annotated with this type is resolved
# through ``get_intent_classifier``.
IntentClassifierDependency = Annotated[IntentClassifier, Depends(get_intent_classifier)]
|