Tweaked CORS
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from fastapi import Security, HTTPException, status, Depends
|
||||
from fastapi.security.api_key import APIKeyHeader
|
||||
from fastapi.security.api_key import APIKeyHeader, APIKeyQuery
|
||||
from sqlalchemy.orm import Session
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
@@ -7,34 +7,38 @@ from app.models.module import Module
|
||||
from cachetools import TTLCache
|
||||
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
|
||||
|
||||
# Cache keys for 5 minutes, store up to 1000 keys
|
||||
# This prevents a Supabase round-trip on every message
|
||||
auth_cache = TTLCache(maxsize=1000, ttl=300)
|
||||
|
||||
async def get_api_key(
|
||||
api_key_header: str = Security(api_key_header),
|
||||
api_key_h: str = Security(api_key_header),
|
||||
api_key_q: str = Security(api_key_query),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
if not api_key_header:
|
||||
# Use header if provided, otherwise fallback to query param
|
||||
api_key = api_key_h or api_key_q
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="API Key missing"
|
||||
)
|
||||
|
||||
# 1. Fallback to global static key (Admin)
|
||||
if api_key_header == settings.API_KEY:
|
||||
return api_key_header
|
||||
if api_key == settings.API_KEY:
|
||||
return api_key
|
||||
|
||||
# 2. Check Cache first (VERY FAST)
|
||||
if api_key_header in auth_cache:
|
||||
return api_key_header
|
||||
if api_key in auth_cache:
|
||||
return api_key
|
||||
|
||||
# 3. Check Database for Module key (Database round-trip)
|
||||
module = db.query(Module).filter(Module.secret_key == api_key_header, Module.is_active == True).first()
|
||||
module = db.query(Module).filter(Module.secret_key == api_key, Module.is_active == True).first()
|
||||
if module:
|
||||
# Save module ID to cache for next time
|
||||
auth_cache[api_key_header] = module.id
|
||||
# Save key to cache for next time
|
||||
auth_cache[api_key] = module.id
|
||||
return module
|
||||
|
||||
raise HTTPException(
|
||||
@@ -43,22 +47,31 @@ async def get_api_key(
|
||||
)
|
||||
|
||||
async def get_current_module(
|
||||
api_key_header: str = Security(api_key_header),
|
||||
api_key_h: str = Security(api_key_header),
|
||||
api_key_q: str = Security(api_key_query),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
api_key = api_key_h or api_key_q
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="API Key missing"
|
||||
)
|
||||
|
||||
# 1. Fallback to global static key (Admin) - No module tracking
|
||||
if api_key_header == settings.API_KEY:
|
||||
if api_key == settings.API_KEY:
|
||||
return None
|
||||
|
||||
# 2. Check Cache
|
||||
if api_key_header in auth_cache:
|
||||
module_id = auth_cache[api_key_header]
|
||||
if api_key in auth_cache:
|
||||
module_id = auth_cache[api_key]
|
||||
return db.query(Module).filter(Module.id == module_id).first()
|
||||
|
||||
# 3. DB Lookup
|
||||
module = db.query(Module).filter(Module.secret_key == api_key_header, Module.is_active == True).first()
|
||||
module = db.query(Module).filter(Module.secret_key == api_key, Module.is_active == True).first()
|
||||
if module:
|
||||
auth_cache[api_key_header] = module.id
|
||||
auth_cache[api_key] = module.id
|
||||
return module
|
||||
|
||||
raise HTTPException(
|
||||
|
||||
@@ -32,15 +32,54 @@ def get_gemini_client():
|
||||
_client = genai.Client(api_key=settings.GOOGLE_API_KEY, http_options={'api_version': 'v1alpha'})
|
||||
return _client
|
||||
|
||||
@router.post("/chat")
|
||||
@router.api_route("/chat", methods=["GET", "POST"])
|
||||
@limiter.limit(settings.RATE_LIMIT)
|
||||
async def gemini_chat(
|
||||
request: Request,
|
||||
chat_data: LLMRequest,
|
||||
api_key: str = Depends(get_api_key),
|
||||
module: Module = Depends(get_current_module),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
chat_data = None
|
||||
|
||||
if request.method == "GET":
|
||||
# Handle GET requests (Ultimate Simple Request)
|
||||
params = request.query_params
|
||||
if not params.get("prompt"):
|
||||
return {"status": "error", "detail": "Missing 'prompt' query parameter"}
|
||||
|
||||
chat_data = LLMRequest(
|
||||
prompt=params.get("prompt"),
|
||||
context=params.get("context", ""),
|
||||
system_prompt=params.get("system_prompt"),
|
||||
knowledge_base=params.get("knowledge_base"),
|
||||
temperature=float(params.get("temperature", 0.7)),
|
||||
top_p=float(params.get("top_p", 0.95)),
|
||||
top_k=int(params.get("top_k", 40)),
|
||||
max_output_tokens=int(params.get("max_output_tokens", 8192))
|
||||
)
|
||||
else:
|
||||
# Handle POST requests
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
if "text/plain" in content_type:
|
||||
try:
|
||||
body = await request.body()
|
||||
import json
|
||||
data = json.loads(body)
|
||||
chat_data = LLMRequest(**data)
|
||||
except Exception as e:
|
||||
return {"status": "error", "detail": f"Failed to parse text/plain as JSON: {str(e)}"}
|
||||
else:
|
||||
# Standard JSON parsing
|
||||
try:
|
||||
data = await request.json()
|
||||
chat_data = LLMRequest(**data)
|
||||
except Exception as e:
|
||||
return {"status": "error", "detail": f"Invalid JSON: {str(e)}"}
|
||||
|
||||
if not chat_data:
|
||||
return {"status": "error", "detail": "Could not determine request data"}
|
||||
|
||||
client = get_gemini_client()
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user