Merge pull request #5 from dannyjosephgarcia/WOOL-25

feat: add GCP service-to-service authentication
This commit is contained in:
Danny Garcia 2026-01-20 11:42:10 -06:00 committed by GitHub
commit 50599d7cee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 136 additions and 2 deletions

114
app/auth.py Normal file
View File

@ -0,0 +1,114 @@
"""GCP Cloud Run service-to-service authentication module."""
import logging
from typing import Annotated
from fastapi import Depends, HTTPException, Request
from google.auth.transport import requests as google_requests
from google.oauth2 import id_token
from app.config import settings
logger = logging.getLogger(__name__)
def verify_gcp_identity_token(token: str, audience: str) -> dict:
"""Verify a GCP identity token and return the decoded claims.
Args:
token: The identity token to verify.
audience: The expected audience (backend Cloud Run URL).
Returns:
The decoded token claims.
Raises:
ValueError: If the token is invalid or verification fails.
"""
try:
claims = id_token.verify_oauth2_token(
token,
google_requests.Request(),
audience=audience,
)
return claims
except Exception as e:
logger.warning(f"Token verification failed: {e}")
raise ValueError(f"Token verification failed: {e}")
async def verify_service_auth(request: Request) -> dict | None:
"""FastAPI dependency to verify GCP service-to-service authentication.
Returns None if auth is disabled (local dev), otherwise verifies the
identity token and checks the service account allowlist.
Returns:
The decoded token claims, or None if auth is disabled.
Raises:
HTTPException: 401 if authentication fails.
"""
# Skip auth if disabled (local development)
if not settings.auth_enabled:
logger.debug("Authentication disabled, skipping verification")
return None
# Extract token from Authorization header
auth_header = request.headers.get("Authorization")
if not auth_header:
logger.warning("Missing Authorization header")
raise HTTPException(
status_code=401,
detail="Missing Authorization header",
)
if not auth_header.startswith("Bearer "):
logger.warning("Invalid Authorization header format")
raise HTTPException(
status_code=401,
detail="Invalid Authorization header format. Expected 'Bearer <token>'",
)
token = auth_header[7:] # Remove "Bearer " prefix
# Verify the token
if not settings.auth_audience:
logger.error("AUTH_AUDIENCE not configured")
raise HTTPException(
status_code=500,
detail="Server authentication not properly configured",
)
try:
claims = verify_gcp_identity_token(token, settings.auth_audience)
except ValueError as e:
logger.warning(f"Token verification failed: {e}")
raise HTTPException(
status_code=401,
detail="Invalid or expired token",
)
# Check service account allowlist if configured
allowed_accounts = settings.allowed_service_accounts_list
if allowed_accounts:
email = claims.get("email", "")
if email not in allowed_accounts:
logger.warning(
f"Service account '{email}' not in allowlist",
extra={"allowed": allowed_accounts},
)
raise HTTPException(
status_code=403,
detail="Service account not authorized",
)
logger.info(
"Service authentication successful",
extra={"service_account": claims.get("email", "unknown")},
)
return claims
# Type alias for clean dependency injection
ServiceAuthDependency = Annotated[dict | None, Depends(verify_service_auth)]

View File

@ -22,11 +22,21 @@ class Settings(BaseSettings):
# CORS configuration # CORS configuration
cors_origins: str = "http://localhost:3000" cors_origins: str = "http://localhost:3000"
# Authentication settings
auth_enabled: bool = False # Set to True in production
auth_audience: str = "" # Backend Cloud Run URL (e.g., https://backend-xxx.run.app)
allowed_service_accounts: str = "" # Comma-separated list of allowed service account emails
@property @property
def cors_origins_list(self) -> list[str]: def cors_origins_list(self) -> list[str]:
"""Parse comma-separated CORS origins into a list.""" """Parse comma-separated CORS origins into a list."""
return [origin.strip() for origin in self.cors_origins.split(",") if origin.strip()] return [origin.strip() for origin in self.cors_origins.split(",") if origin.strip()]
@property
def allowed_service_accounts_list(self) -> list[str]:
"""Parse comma-separated allowed service accounts into a list."""
return [sa.strip() for sa in self.allowed_service_accounts.split(",") if sa.strip()]
class Config: class Config:
env_file = ".env" env_file = ".env"
env_file_encoding = "utf-8" env_file_encoding = "utf-8"

View File

@ -5,6 +5,7 @@ from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from app.auth import ServiceAuthDependency
from app.config import settings, MAX_MESSAGE_LENGTH from app.config import settings, MAX_MESSAGE_LENGTH
from app.llm import AdapterDependency, LLMError, llm_exception_to_http from app.llm import AdapterDependency, LLMError, llm_exception_to_http
from app.schemas import ( from app.schemas import (
@ -47,7 +48,11 @@ async def health_check() -> HealthResponse:
@app.post("/chat", response_model=ChatResponse) @app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest, adapter: AdapterDependency) -> ChatResponse: async def chat(
request: ChatRequest,
adapter: AdapterDependency,
_auth: ServiceAuthDependency,
) -> ChatResponse:
"""Process a chat message through the LLM adapter. """Process a chat message through the LLM adapter.
- Validates message length - Validates message length
@ -108,7 +113,11 @@ async def chat(request: ChatRequest, adapter: AdapterDependency) -> ChatResponse
@app.post("/chat/stream") @app.post("/chat/stream")
async def chat_stream(request: ChatRequest, adapter: AdapterDependency) -> StreamingResponse: async def chat_stream(
request: ChatRequest,
adapter: AdapterDependency,
_auth: ServiceAuthDependency,
) -> StreamingResponse:
"""Stream a chat response through the LLM adapter using Server-Sent Events. """Stream a chat response through the LLM adapter using Server-Sent Events.
- Validates message length - Validates message length

View File

@ -8,6 +8,7 @@ click==8.3.1
colorama==0.4.6 colorama==0.4.6
distro==1.9.0 distro==1.9.0
fastapi==0.128.0 fastapi==0.128.0
google-auth>=2.20.0
h11==0.16.0 h11==0.16.0
httpcore==1.0.9 httpcore==1.0.9
httptools==0.7.1 httptools==0.7.1