feat: add GCP service-to-service authentication
Implement identity token verification for Cloud Run deployments: - Add auth module with GCP identity token verification - Add configurable auth settings (AUTH_ENABLED, AUTH_AUDIENCE) - Add service account allowlist for access control - Protect /chat and /chat/stream endpoints with auth dependency - Add google-auth dependency for token verification Auth can be disabled for local development via AUTH_ENABLED=false.
This commit is contained in:
parent
e3c4680108
commit
c9336d1d84
|
|
@ -0,0 +1,114 @@
|
||||||
|
"""GCP Cloud Run service-to-service authentication module."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, Request
|
||||||
|
from google.auth.transport import requests as google_requests
|
||||||
|
from google.oauth2 import id_token
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_gcp_identity_token(token: str, audience: str) -> dict:
|
||||||
|
"""Verify a GCP identity token and return the decoded claims.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The identity token to verify.
|
||||||
|
audience: The expected audience (backend Cloud Run URL).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The decoded token claims.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the token is invalid or verification fails.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
claims = id_token.verify_oauth2_token(
|
||||||
|
token,
|
||||||
|
google_requests.Request(),
|
||||||
|
audience=audience,
|
||||||
|
)
|
||||||
|
return claims
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Token verification failed: {e}")
|
||||||
|
raise ValueError(f"Token verification failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_service_auth(request: Request) -> dict | None:
|
||||||
|
"""FastAPI dependency to verify GCP service-to-service authentication.
|
||||||
|
|
||||||
|
Returns None if auth is disabled (local dev), otherwise verifies the
|
||||||
|
identity token and checks the service account allowlist.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The decoded token claims, or None if auth is disabled.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: 401 if authentication fails.
|
||||||
|
"""
|
||||||
|
# Skip auth if disabled (local development)
|
||||||
|
if not settings.auth_enabled:
|
||||||
|
logger.debug("Authentication disabled, skipping verification")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Extract token from Authorization header
|
||||||
|
auth_header = request.headers.get("Authorization")
|
||||||
|
if not auth_header:
|
||||||
|
logger.warning("Missing Authorization header")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail="Missing Authorization header",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not auth_header.startswith("Bearer "):
|
||||||
|
logger.warning("Invalid Authorization header format")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail="Invalid Authorization header format. Expected 'Bearer <token>'",
|
||||||
|
)
|
||||||
|
|
||||||
|
token = auth_header[7:] # Remove "Bearer " prefix
|
||||||
|
|
||||||
|
# Verify the token
|
||||||
|
if not settings.auth_audience:
|
||||||
|
logger.error("AUTH_AUDIENCE not configured")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="Server authentication not properly configured",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
claims = verify_gcp_identity_token(token, settings.auth_audience)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warning(f"Token verification failed: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail="Invalid or expired token",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check service account allowlist if configured
|
||||||
|
allowed_accounts = settings.allowed_service_accounts_list
|
||||||
|
if allowed_accounts:
|
||||||
|
email = claims.get("email", "")
|
||||||
|
if email not in allowed_accounts:
|
||||||
|
logger.warning(
|
||||||
|
f"Service account '{email}' not in allowlist",
|
||||||
|
extra={"allowed": allowed_accounts},
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="Service account not authorized",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Service authentication successful",
|
||||||
|
extra={"service_account": claims.get("email", "unknown")},
|
||||||
|
)
|
||||||
|
return claims
|
||||||
|
|
||||||
|
|
||||||
|
# Type alias for clean dependency injection
|
||||||
|
ServiceAuthDependency = Annotated[dict | None, Depends(verify_service_auth)]
|
||||||
|
|
@ -22,11 +22,21 @@ class Settings(BaseSettings):
|
||||||
# CORS configuration
|
# CORS configuration
|
||||||
cors_origins: str = "http://localhost:3000"
|
cors_origins: str = "http://localhost:3000"
|
||||||
|
|
||||||
|
# Authentication settings
|
||||||
|
auth_enabled: bool = False # Set to True in production
|
||||||
|
auth_audience: str = "" # Backend Cloud Run URL (e.g., https://backend-xxx.run.app)
|
||||||
|
allowed_service_accounts: str = "" # Comma-separated list of allowed service account emails
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cors_origins_list(self) -> list[str]:
|
def cors_origins_list(self) -> list[str]:
|
||||||
"""Parse comma-separated CORS origins into a list."""
|
"""Parse comma-separated CORS origins into a list."""
|
||||||
return [origin.strip() for origin in self.cors_origins.split(",") if origin.strip()]
|
return [origin.strip() for origin in self.cors_origins.split(",") if origin.strip()]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def allowed_service_accounts_list(self) -> list[str]:
|
||||||
|
"""Parse comma-separated allowed service accounts into a list."""
|
||||||
|
return [sa.strip() for sa in self.allowed_service_accounts.split(",") if sa.strip()]
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file = ".env"
|
env_file = ".env"
|
||||||
env_file_encoding = "utf-8"
|
env_file_encoding = "utf-8"
|
||||||
|
|
|
||||||
13
app/main.py
13
app/main.py
|
|
@ -5,6 +5,7 @@ from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
from app.auth import ServiceAuthDependency
|
||||||
from app.config import settings, MAX_MESSAGE_LENGTH
|
from app.config import settings, MAX_MESSAGE_LENGTH
|
||||||
from app.llm import AdapterDependency, LLMError, llm_exception_to_http
|
from app.llm import AdapterDependency, LLMError, llm_exception_to_http
|
||||||
from app.schemas import (
|
from app.schemas import (
|
||||||
|
|
@ -47,7 +48,11 @@ async def health_check() -> HealthResponse:
|
||||||
|
|
||||||
|
|
||||||
@app.post("/chat", response_model=ChatResponse)
|
@app.post("/chat", response_model=ChatResponse)
|
||||||
async def chat(request: ChatRequest, adapter: AdapterDependency) -> ChatResponse:
|
async def chat(
|
||||||
|
request: ChatRequest,
|
||||||
|
adapter: AdapterDependency,
|
||||||
|
_auth: ServiceAuthDependency,
|
||||||
|
) -> ChatResponse:
|
||||||
"""Process a chat message through the LLM adapter.
|
"""Process a chat message through the LLM adapter.
|
||||||
|
|
||||||
- Validates message length
|
- Validates message length
|
||||||
|
|
@ -108,7 +113,11 @@ async def chat(request: ChatRequest, adapter: AdapterDependency) -> ChatResponse
|
||||||
|
|
||||||
|
|
||||||
@app.post("/chat/stream")
|
@app.post("/chat/stream")
|
||||||
async def chat_stream(request: ChatRequest, adapter: AdapterDependency) -> StreamingResponse:
|
async def chat_stream(
|
||||||
|
request: ChatRequest,
|
||||||
|
adapter: AdapterDependency,
|
||||||
|
_auth: ServiceAuthDependency,
|
||||||
|
) -> StreamingResponse:
|
||||||
"""Stream a chat response through the LLM adapter using Server-Sent Events.
|
"""Stream a chat response through the LLM adapter using Server-Sent Events.
|
||||||
|
|
||||||
- Validates message length
|
- Validates message length
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ click==8.3.1
|
||||||
colorama==0.4.6
|
colorama==0.4.6
|
||||||
distro==1.9.0
|
distro==1.9.0
|
||||||
fastapi==0.128.0
|
fastapi==0.128.0
|
||||||
|
google-auth>=2.20.0
|
||||||
h11==0.16.0
|
h11==0.16.0
|
||||||
httpcore==1.0.9
|
httpcore==1.0.9
|
||||||
httptools==0.7.1
|
httptools==0.7.1
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue