Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9325a4919a | |||
| f6ad186178 |
@@ -1,5 +1,5 @@
|
|||||||
from fastapi import Security, HTTPException, status, Depends
|
from fastapi import Security, HTTPException, status, Depends
|
||||||
from fastapi.security.api_key import APIKeyHeader
|
from fastapi.security.api_key import APIKeyHeader, APIKeyQuery
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
@@ -7,6 +7,7 @@ from app.models.module import Module
|
|||||||
from cachetools import TTLCache
|
from cachetools import TTLCache
|
||||||
|
|
||||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||||
|
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
|
||||||
|
|
||||||
# Cache keys for 5 minutes, store up to 1000 keys
|
# Cache keys for 5 minutes, store up to 1000 keys
|
||||||
# This prevents a Supabase round-trip on every message
|
# This prevents a Supabase round-trip on every message
|
||||||
@@ -14,27 +15,31 @@ auth_cache = TTLCache(maxsize=1000, ttl=300)
|
|||||||
|
|
||||||
async def get_api_key(
|
async def get_api_key(
|
||||||
api_key_header: str = Security(api_key_header),
|
api_key_header: str = Security(api_key_header),
|
||||||
|
api_key_query: str = Security(api_key_query),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
if not api_key_header:
|
# Use header if provided, otherwise fallback to query param
|
||||||
|
api_key = api_key_header or api_key_query
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="API Key missing"
|
detail="API Key missing"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. Fallback to global static key (Admin)
|
# 1. Fallback to global static key (Admin)
|
||||||
if api_key_header == settings.API_KEY:
|
if api_key == settings.API_KEY:
|
||||||
return api_key_header
|
return api_key
|
||||||
|
|
||||||
# 2. Check Cache first (VERY FAST)
|
# 2. Check Cache first (VERY FAST)
|
||||||
if api_key_header in auth_cache:
|
if api_key in auth_cache:
|
||||||
return api_key_header
|
return api_key
|
||||||
|
|
||||||
# 3. Check Database for Module key (Database round-trip)
|
# 3. Check Database for Module key (Database round-trip)
|
||||||
module = db.query(Module).filter(Module.secret_key == api_key_header, Module.is_active == True).first()
|
module = db.query(Module).filter(Module.secret_key == api_key, Module.is_active == True).first()
|
||||||
if module:
|
if module:
|
||||||
# Save module ID to cache for next time
|
# Save module ID to cache for next time
|
||||||
auth_cache[api_key_header] = module.id
|
auth_cache[api_key] = module.id
|
||||||
return module
|
return module
|
||||||
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -44,21 +49,31 @@ async def get_api_key(
|
|||||||
|
|
||||||
async def get_current_module(
|
async def get_current_module(
|
||||||
api_key_header: str = Security(api_key_header),
|
api_key_header: str = Security(api_key_header),
|
||||||
|
api_key_query: str = Security(api_key_query),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
|
# Use header if provided, otherwise fallback to query param
|
||||||
|
api_key = api_key_header or api_key_query
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="API Key missing"
|
||||||
|
)
|
||||||
|
|
||||||
# 1. Fallback to global static key (Admin) - No module tracking
|
# 1. Fallback to global static key (Admin) - No module tracking
|
||||||
if api_key_header == settings.API_KEY:
|
if api_key == settings.API_KEY:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 2. Check Cache
|
# 2. Check Cache
|
||||||
if api_key_header in auth_cache:
|
if api_key in auth_cache:
|
||||||
module_id = auth_cache[api_key_header]
|
module_id = auth_cache[api_key]
|
||||||
return db.query(Module).filter(Module.id == module_id).first()
|
return db.query(Module).filter(Module.id == module_id).first()
|
||||||
|
|
||||||
# 3. DB Lookup
|
# 3. DB Lookup
|
||||||
module = db.query(Module).filter(Module.secret_key == api_key_header, Module.is_active == True).first()
|
module = db.query(Module).filter(Module.secret_key == api_key, Module.is_active == True).first()
|
||||||
if module:
|
if module:
|
||||||
auth_cache[api_key_header] = module.id
|
auth_cache[api_key] = module.id
|
||||||
return module
|
return module
|
||||||
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|||||||
@@ -36,11 +36,28 @@ def get_gemini_client():
|
|||||||
@limiter.limit(settings.RATE_LIMIT)
|
@limiter.limit(settings.RATE_LIMIT)
|
||||||
async def gemini_chat(
|
async def gemini_chat(
|
||||||
request: Request,
|
request: Request,
|
||||||
chat_data: LLMRequest,
|
|
||||||
api_key: str = Depends(get_api_key),
|
api_key: str = Depends(get_api_key),
|
||||||
module: Module = Depends(get_current_module),
|
module: Module = Depends(get_current_module),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
|
# Handle text/plain as JSON (fallback for CORS "Simple Requests")
|
||||||
|
content_type = request.headers.get("Content-Type", "")
|
||||||
|
if "text/plain" in content_type:
|
||||||
|
try:
|
||||||
|
body = await request.body()
|
||||||
|
import json
|
||||||
|
data = json.loads(body)
|
||||||
|
chat_data = LLMRequest(**data)
|
||||||
|
except Exception as e:
|
||||||
|
return {"status": "error", "detail": f"Failed to parse text/plain as JSON: {str(e)}"}
|
||||||
|
else:
|
||||||
|
# Standard JSON parsing
|
||||||
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
chat_data = LLMRequest(**data)
|
||||||
|
except Exception as e:
|
||||||
|
return {"status": "error", "detail": f"Invalid JSON: {str(e)}"}
|
||||||
|
|
||||||
client = get_gemini_client()
|
client = get_gemini_client()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
services:
|
services:
|
||||||
api:
|
api:
|
||||||
build: .
|
build: .
|
||||||
container_name: storyline-ai-gateway
|
container_name: ai-gateway
|
||||||
|
networks:
|
||||||
|
- caddy_network
|
||||||
ports:
|
ports:
|
||||||
- "8191:8000"
|
- "8000:8000"
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- .env
|
||||||
restart: always
|
restart: always
|
||||||
@@ -11,3 +13,8 @@ services:
|
|||||||
- .:/app
|
- .:/app
|
||||||
# Override command for development/auto-reload if needed
|
# Override command for development/auto-reload if needed
|
||||||
command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||||
|
|
||||||
|
networks:
|
||||||
|
caddy_network:
|
||||||
|
# Define the network at the bottom
|
||||||
|
external: true
|
||||||
|
|||||||
BIN
server_log.txt
Normal file
BIN
server_log.txt
Normal file
Binary file not shown.
Reference in New Issue
Block a user