Files
ai-gateway/app/api/deps.py

83 lines
2.7 KiB
Python

import secrets

from cachetools import TTLCache
from fastapi import Security, HTTPException, status, Depends
from fastapi.security.api_key import APIKeyHeader, APIKeyQuery
from sqlalchemy.orm import Session

from app.core.config import settings
from app.core.database import get_db
from app.models.module import Module
# Where an API key may be supplied: the "api_key" query parameter or the
# "X-API-Key" header. Both use auto_error=False, so a missing key does not
# fail automatically; the dependencies below report it themselves.
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)

# Module keys that recently passed the database check, mapped to the owning
# Module.id. Entries expire after _AUTH_CACHE_TTL_SECONDS and at most
# _AUTH_CACHE_MAXSIZE keys are held. This saves a database (Supabase)
# round-trip per message, at the cost that a key deactivated in the database
# may still be accepted by a cache hit until its entry expires.
_AUTH_CACHE_MAXSIZE = 1000
_AUTH_CACHE_TTL_SECONDS = 300
auth_cache = TTLCache(maxsize=_AUTH_CACHE_MAXSIZE, ttl=_AUTH_CACHE_TTL_SECONDS)
async def get_api_key(
    api_key_header: str = Security(api_key_header),
    api_key_query: str = Security(api_key_query),
    db: Session = Depends(get_db),
):
    """Authenticate the request and return the API key it presented.

    The key is taken from the ``X-API-Key`` header, falling back to the
    ``api_key`` query parameter. It is accepted when any of these holds:

    * it is the global admin key (``settings.API_KEY``);
    * it is present in ``auth_cache`` from an earlier successful lookup;
    * it is the ``secret_key`` of an active ``Module`` row, in which case it
      is then cached.

    Cache hits skip the database. A module deactivated after its key was
    cached therefore keeps being accepted until the cache entry expires.

    Returns:
        The presented API key (``str``) on every success path.

    Raises:
        HTTPException: 403 if no key was supplied, or if it matches neither
            the admin key nor an active module.
    """
    # The header wins; the query parameter is only a fallback.
    api_key = api_key_header or api_key_query
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="API Key missing",
        )

    # 1. Global static key (Admin). compare_digest runs in constant time, so
    #    response timing does not reveal how much of the admin key matched.
    admin_key = settings.API_KEY
    if admin_key and secrets.compare_digest(
        api_key.encode("utf-8"), admin_key.encode("utf-8")
    ):
        return api_key

    # 2. Recently validated module key: no database round-trip needed.
    if api_key in auth_cache:
        return api_key

    # 3. Database lookup for an active module that owns this key.
    # NOTE(review): this is a synchronous Session query inside an async
    # function, so it blocks the event loop while it runs.
    module = (
        db.query(Module)
        # `== True` builds a SQL expression here, not a Python bool test.
        .filter(Module.secret_key == api_key, Module.is_active == True)  # noqa: E712
        .first()
    )
    if module is not None:
        auth_cache[api_key] = module.id
        # Return the key rather than the Module, so every success path yields
        # the same type as the admin and cache branches above.
        return api_key

    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Could not validate credentials or API Key is inactive",
    )
async def get_current_module(
    api_key_header: str = Security(api_key_header),
    api_key_query: str = Security(api_key_query),
    db: Session = Depends(get_db),
):
    """Resolve the presented API key to the ``Module`` that owns it.

    The key is taken from the ``X-API-Key`` header, falling back to the
    ``api_key`` query parameter. ``auth_cache`` maps previously validated keys
    to a ``Module.id``. A cache hit fetches the module by primary key, but the
    module must still be active.

    Returns:
        ``None`` for the global admin key, so the request is not attributed
        to any module. Otherwise, the active ``Module`` whose ``secret_key``
        matches.

    Raises:
        HTTPException: 403 if no key was supplied, or if it does not belong
            to an active module.
    """
    # The header wins; the query parameter is only a fallback.
    api_key = api_key_header or api_key_query
    if not api_key:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="API Key missing",
        )

    # 1. Global static key (Admin): no module tracking. compare_digest runs
    #    in constant time, so timing does not reveal partial matches.
    admin_key = settings.API_KEY
    if admin_key and secrets.compare_digest(
        api_key.encode("utf-8"), admin_key.encode("utf-8")
    ):
        return None

    # 2. Cache hit: look up by primary key, but still require is_active.
    #    Without that check, a deactivated module would be returned. A deleted
    #    one would yield None, which is indistinguishable from the admin
    #    result above. A single .get() also avoids the window between an `in`
    #    test and a separate `[]` lookup in which the TTL entry could expire.
    # NOTE(review): these are synchronous Session queries inside an async
    # function, so they block the event loop while they run.
    module_id = auth_cache.get(api_key)
    if module_id is not None:
        module = (
            db.query(Module)
            # `== True` builds a SQL expression here, not a Python bool test.
            .filter(Module.id == module_id, Module.is_active == True)  # noqa: E712
            .first()
        )
        if module is not None:
            return module
        # Stale entry (module deactivated or removed): drop it and
        # re-validate the key from scratch below.
        auth_cache.pop(api_key, None)

    # 3. Database lookup by secret key.
    module = (
        db.query(Module)
        .filter(Module.secret_key == api_key, Module.is_active == True)  # noqa: E712
        .first()
    )
    if module is not None:
        auth_cache[api_key] = module.id
        return module

    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Could not validate credentials",
    )