Compare commits

2 Commits
main ... tests

Author SHA1 Message Date
9325a4919a updated fallback logic 2026-02-10 20:27:12 +08:00
f6ad186178 added query parameter loading for the secret key 2026-02-10 19:48:29 +08:00
4 changed files with 55 additions and 16 deletions

View File

@@ -1,5 +1,5 @@
from fastapi import Security, HTTPException, status, Depends from fastapi import Security, HTTPException, status, Depends
from fastapi.security.api_key import APIKeyHeader from fastapi.security.api_key import APIKeyHeader, APIKeyQuery
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.config import settings from app.core.config import settings
from app.core.database import get_db from app.core.database import get_db
@@ -7,6 +7,7 @@ from app.models.module import Module
from cachetools import TTLCache from cachetools import TTLCache
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
# Cache keys for 5 minutes, store up to 1000 keys # Cache keys for 5 minutes, store up to 1000 keys
# This prevents a Supabase round-trip on every message # This prevents a Supabase round-trip on every message
@@ -14,27 +15,31 @@ auth_cache = TTLCache(maxsize=1000, ttl=300)
async def get_api_key( async def get_api_key(
api_key_header: str = Security(api_key_header), api_key_header: str = Security(api_key_header),
api_key_query: str = Security(api_key_query),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
if not api_key_header: # Use header if provided, otherwise fallback to query param
api_key = api_key_header or api_key_query
if not api_key:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="API Key missing" detail="API Key missing"
) )
# 1. Fallback to global static key (Admin) # 1. Fallback to global static key (Admin)
if api_key_header == settings.API_KEY: if api_key == settings.API_KEY:
return api_key_header return api_key
# 2. Check Cache first (VERY FAST) # 2. Check Cache first (VERY FAST)
if api_key_header in auth_cache: if api_key in auth_cache:
return api_key_header return api_key
# 3. Check Database for Module key (Database round-trip) # 3. Check Database for Module key (Database round-trip)
module = db.query(Module).filter(Module.secret_key == api_key_header, Module.is_active == True).first() module = db.query(Module).filter(Module.secret_key == api_key, Module.is_active == True).first()
if module: if module:
# Save module ID to cache for next time # Save module ID to cache for next time
auth_cache[api_key_header] = module.id auth_cache[api_key] = module.id
return module return module
raise HTTPException( raise HTTPException(
@@ -44,21 +49,31 @@ async def get_api_key(
async def get_current_module( async def get_current_module(
api_key_header: str = Security(api_key_header), api_key_header: str = Security(api_key_header),
api_key_query: str = Security(api_key_query),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
# Use header if provided, otherwise fallback to query param
api_key = api_key_header or api_key_query
if not api_key:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="API Key missing"
)
# 1. Fallback to global static key (Admin) - No module tracking # 1. Fallback to global static key (Admin) - No module tracking
if api_key_header == settings.API_KEY: if api_key == settings.API_KEY:
return None return None
# 2. Check Cache # 2. Check Cache
if api_key_header in auth_cache: if api_key in auth_cache:
module_id = auth_cache[api_key_header] module_id = auth_cache[api_key]
return db.query(Module).filter(Module.id == module_id).first() return db.query(Module).filter(Module.id == module_id).first()
# 3. DB Lookup # 3. DB Lookup
module = db.query(Module).filter(Module.secret_key == api_key_header, Module.is_active == True).first() module = db.query(Module).filter(Module.secret_key == api_key, Module.is_active == True).first()
if module: if module:
auth_cache[api_key_header] = module.id auth_cache[api_key] = module.id
return module return module
raise HTTPException( raise HTTPException(

View File

@@ -36,11 +36,28 @@ def get_gemini_client():
@limiter.limit(settings.RATE_LIMIT) @limiter.limit(settings.RATE_LIMIT)
async def gemini_chat( async def gemini_chat(
request: Request, request: Request,
chat_data: LLMRequest,
api_key: str = Depends(get_api_key), api_key: str = Depends(get_api_key),
module: Module = Depends(get_current_module), module: Module = Depends(get_current_module),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
# Handle text/plain as JSON (fallback for CORS "Simple Requests")
content_type = request.headers.get("Content-Type", "")
if "text/plain" in content_type:
try:
body = await request.body()
import json
data = json.loads(body)
chat_data = LLMRequest(**data)
except Exception as e:
return {"status": "error", "detail": f"Failed to parse text/plain as JSON: {str(e)}"}
else:
# Standard JSON parsing
try:
data = await request.json()
chat_data = LLMRequest(**data)
except Exception as e:
return {"status": "error", "detail": f"Invalid JSON: {str(e)}"}
client = get_gemini_client() client = get_gemini_client()
try: try:

View File

@@ -1,9 +1,11 @@
services: services:
api: api:
build: . build: .
container_name: storyline-ai-gateway container_name: ai-gateway
networks:
- caddy_network
ports: ports:
- "8191:8000" - "8000:8000"
env_file: env_file:
- .env - .env
restart: always restart: always
@@ -11,3 +13,8 @@ services:
- .:/app - .:/app
# Override command for development/auto-reload if needed # Override command for development/auto-reload if needed
command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
networks:
caddy_network:
# Define the network at the bottom
external: true

BIN
server_log.txt Normal file

Binary file not shown.