From c9336d1d8442766e8233511480e6ac9554123dd2 Mon Sep 17 00:00:00 2001
From: Danny
Date: Mon, 19 Jan 2026 11:06:59 -0600
Subject: [PATCH] feat: add GCP service-to-service authentication

Implement identity token verification for Cloud Run deployments:
- Add auth module with GCP identity token verification
- Add configurable auth settings (AUTH_ENABLED, AUTH_AUDIENCE)
- Add service account allowlist for access control
- Protect /chat and /chat/stream endpoints with auth dependency
- Add google-auth dependency for token verification

Auth is disabled by default for local development (AUTH_ENABLED=false);
set AUTH_ENABLED=true and AUTH_AUDIENCE in production.
---
Review notes: this patch was revised by hand after review; the post-image
blob hashes on the "index" lines below were not regenerated.

 app/auth.py      | 134 +++++++++++++++++++++++++++++++++++++++++++++++
 app/config.py    |  10 ++++
 app/main.py      |  13 ++++-
 requirements.txt |   1 +
 4 files changed, 156 insertions(+), 2 deletions(-)
 create mode 100644 app/auth.py

diff --git a/app/auth.py b/app/auth.py
new file mode 100644
index 0000000..a002e1d
--- /dev/null
+++ b/app/auth.py
@@ -0,0 +1,134 @@
+"""GCP Cloud Run service-to-service authentication module."""
+
+import asyncio
+import logging
+from typing import Annotated
+
+from fastapi import Depends, HTTPException, Request
+from google.auth import exceptions as google_exceptions
+from google.auth.transport import requests as google_requests
+from google.oauth2 import id_token
+
+from app.config import settings
+
+logger = logging.getLogger(__name__)
+
+# RFC 6750: 401 responses to bearer-token requests carry a Bearer challenge.
+_BEARER_CHALLENGE = {"WWW-Authenticate": "Bearer"}
+
+
+def verify_gcp_identity_token(token: str, audience: str) -> dict:
+    """Verify a GCP identity token and return the decoded claims.
+
+    This call blocks: google-auth fetches Google's signing certificates over
+    HTTP, so async callers should run it in a worker thread.
+
+    Args:
+        token: The identity token to verify.
+        audience: The expected audience (backend Cloud Run URL).
+
+    Returns:
+        The decoded token claims.
+
+    Raises:
+        ValueError: If the token is invalid or verification fails.
+    """
+    try:
+        return id_token.verify_oauth2_token(
+            token,
+            google_requests.Request(),
+            audience=audience,
+        )
+    except (ValueError, google_exceptions.GoogleAuthError) as e:
+        # Normalise to ValueError and keep the cause; the caller logs it once.
+        raise ValueError(f"Token verification failed: {e}") from e
+
+
+async def verify_service_auth(request: Request) -> dict | None:
+    """FastAPI dependency to verify GCP service-to-service authentication.
+
+    Returns None if auth is disabled (local dev), otherwise verifies the
+    identity token and checks the service account allowlist.
+
+    Returns:
+        The decoded token claims, or None if auth is disabled.
+
+    Raises:
+        HTTPException: 401 if the Authorization header or token is missing or
+            invalid, 403 if the caller is not on the allowlist, 500 if
+            AUTH_AUDIENCE is not configured.
+    """
+    # Skip auth if disabled (local development)
+    if not settings.auth_enabled:
+        logger.debug("Authentication disabled, skipping verification")
+        return None
+
+    # Extract token from Authorization header
+    auth_header = request.headers.get("Authorization")
+    if not auth_header:
+        logger.warning("Missing Authorization header")
+        raise HTTPException(
+            status_code=401,
+            detail="Missing Authorization header",
+            headers=_BEARER_CHALLENGE,
+        )
+
+    # The scheme name is case-insensitive (RFC 7235); a scheme with no token
+    # is malformed rather than merely invalid.
+    scheme, _, token = auth_header.partition(" ")
+    token = token.strip()
+    if scheme.lower() != "bearer" or not token:
+        logger.warning("Invalid Authorization header format")
+        raise HTTPException(
+            status_code=401,
+            detail="Invalid Authorization header format. Expected 'Bearer <token>'",
+            headers=_BEARER_CHALLENGE,
+        )
+
+    # Verify the token
+    if not settings.auth_audience:
+        logger.error("AUTH_AUDIENCE not configured")
+        raise HTTPException(
+            status_code=500,
+            detail="Server authentication not properly configured",
+        )
+
+    try:
+        # Verification does blocking HTTP; keep it off the event loop.
+        claims = await asyncio.to_thread(
+            verify_gcp_identity_token, token, settings.auth_audience
+        )
+    except ValueError as e:
+        logger.warning("Rejected identity token: %s", e)
+        raise HTTPException(
+            status_code=401,
+            detail="Invalid or expired token",
+            headers=_BEARER_CHALLENGE,
+        ) from e
+
+    # Check service account allowlist if configured
+    allowed_accounts = settings.allowed_service_accounts_list
+    if allowed_accounts:
+        email = claims.get("email", "")
+        # Only trust the email claim when the issuer marks it verified.
+        verified = claims.get("email_verified") is True
+        if not verified or email not in allowed_accounts:
+            logger.warning(
+                "Service account %r not in allowlist or email unverified",
+                email,
+                extra={"allowed": allowed_accounts},
+            )
+            raise HTTPException(
+                status_code=403,
+                detail="Service account not authorized",
+            )
+
+    logger.info(
+        "Service authentication successful",
+        extra={"service_account": claims.get("email", "unknown")},
+    )
+    return claims
+
+
+# Type alias for clean dependency injection
+ServiceAuthDependency = Annotated[dict | None, Depends(verify_service_auth)]
diff --git a/app/config.py b/app/config.py
index e1af224..93ba45f 100644
--- a/app/config.py
+++ b/app/config.py
@@ -22,11 +22,21 @@ class Settings(BaseSettings):
     # CORS configuration
     cors_origins: str = "http://localhost:3000"
 
+    # Authentication settings
+    auth_enabled: bool = False  # Set to True in production
+    auth_audience: str = ""  # Backend Cloud Run URL (e.g., https://backend-xxx.run.app)
+    allowed_service_accounts: str = ""  # Comma-separated list of allowed service account emails
+
     @property
     def cors_origins_list(self) -> list[str]:
         """Parse comma-separated CORS origins into a list."""
         return [origin.strip() for origin in self.cors_origins.split(",") if origin.strip()]
 
+    @property
+    def allowed_service_accounts_list(self) -> list[str]:
+        """Parse comma-separated allowed service accounts into a list."""
+        return [sa.strip() for sa in self.allowed_service_accounts.split(",") if sa.strip()]
+
     class Config:
         env_file = ".env"
         env_file_encoding = "utf-8"
diff --git a/app/main.py b/app/main.py
index cae6f15..175b12a 100644
--- a/app/main.py
+++ b/app/main.py
@@ -5,6 +5,7 @@ from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 
+from app.auth import ServiceAuthDependency
 from app.config import settings, MAX_MESSAGE_LENGTH
 from app.llm import AdapterDependency, LLMError, llm_exception_to_http
 from app.schemas import (
@@ -47,7 +48,11 @@ async def health_check() -> HealthResponse:
 
 
 @app.post("/chat", response_model=ChatResponse)
-async def chat(request: ChatRequest, adapter: AdapterDependency) -> ChatResponse:
+async def chat(
+    request: ChatRequest,
+    adapter: AdapterDependency,
+    _auth: ServiceAuthDependency,
+) -> ChatResponse:
     """Process a chat message through the LLM adapter.
 
     - Validates message length
@@ -108,7 +113,11 @@ async def chat(request: ChatRequest, adapter: AdapterDependency) -> ChatResponse
 
 
 @app.post("/chat/stream")
-async def chat_stream(request: ChatRequest, adapter: AdapterDependency) -> StreamingResponse:
+async def chat_stream(
+    request: ChatRequest,
+    adapter: AdapterDependency,
+    _auth: ServiceAuthDependency,
+) -> StreamingResponse:
     """Stream a chat response through the LLM adapter using Server-Sent Events.
 
     - Validates message length
diff --git a/requirements.txt b/requirements.txt
index ef4b621..5066cb7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,6 +8,7 @@ click==8.3.1
 colorama==0.4.6
 distro==1.9.0
 fastapi==0.128.0
+google-auth[requests]>=2.20.0
 h11==0.16.0
 httpcore==1.0.9
 httptools==0.7.1