From b15e39d01138a0f558a503cafb19258c86f0eeb3 Mon Sep 17 00:00:00 2001 From: Paulo Reyes Date: Tue, 10 Feb 2026 21:21:42 +0800 Subject: [PATCH] Tweaked CORS --- app/api/deps.py | 47 +++++++++++++++++++++++-------------- app/api/endpoints/gemini.py | 43 +++++++++++++++++++++++++++++++-- app/main.py | 13 ++++++++-- 3 files changed, 82 insertions(+), 21 deletions(-) diff --git a/app/api/deps.py b/app/api/deps.py index a24cee6..7d2382b 100644 --- a/app/api/deps.py +++ b/app/api/deps.py @@ -1,5 +1,5 @@ from fastapi import Security, HTTPException, status, Depends -from fastapi.security.api_key import APIKeyHeader +from fastapi.security.api_key import APIKeyHeader, APIKeyQuery from sqlalchemy.orm import Session from app.core.config import settings from app.core.database import get_db @@ -7,34 +7,38 @@ from app.models.module import Module from cachetools import TTLCache api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) +api_key_query = APIKeyQuery(name="api_key", auto_error=False) # Cache keys for 5 minutes, store up to 1000 keys -# This prevents a Supabase round-trip on every message auth_cache = TTLCache(maxsize=1000, ttl=300) async def get_api_key( - api_key_header: str = Security(api_key_header), + api_key_h: str = Security(api_key_header), + api_key_q: str = Security(api_key_query), db: Session = Depends(get_db) ): - if not api_key_header: + # Use header if provided, otherwise fallback to query param + api_key = api_key_h or api_key_q + + if not api_key: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="API Key missing" ) # 1. Fallback to global static key (Admin) - if api_key_header == settings.API_KEY: - return api_key_header + if api_key == settings.API_KEY: + return api_key # 2. Check Cache first (VERY FAST) - if api_key_header in auth_cache: - return api_key_header + if api_key in auth_cache: + return api_key # 3. Check Database for Module key (Database round-trip) - module = db.query(Module).filter(Module.secret_key == api_key_header, Module.is_active == True).first() + module = db.query(Module).filter(Module.secret_key == api_key, Module.is_active == True).first() if module: - # Save module ID to cache for next time - auth_cache[api_key_header] = module.id + # Save key to cache for next time + auth_cache[api_key] = module.id return module raise HTTPException( @@ -43,22 +47,31 @@ async def get_api_key( ) async def get_current_module( - api_key_header: str = Security(api_key_header), + api_key_h: str = Security(api_key_header), + api_key_q: str = Security(api_key_query), db: Session = Depends(get_db) ): + api_key = api_key_h or api_key_q + + if not api_key: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="API Key missing" + ) + # 1. Fallback to global static key (Admin) - No module tracking - if api_key_header == settings.API_KEY: + if api_key == settings.API_KEY: return None # 2. Check Cache - if api_key_header in auth_cache: - module_id = auth_cache[api_key_header] + if api_key in auth_cache: + module_id = auth_cache[api_key] return db.query(Module).filter(Module.id == module_id).first() # 3. DB Lookup - module = db.query(Module).filter(Module.secret_key == api_key_header, Module.is_active == True).first() + module = db.query(Module).filter(Module.secret_key == api_key, Module.is_active == True).first() if module: - auth_cache[api_key_header] = module.id + auth_cache[api_key] = module.id return module raise HTTPException( diff --git a/app/api/endpoints/gemini.py b/app/api/endpoints/gemini.py index e0dfccb..204e569 100644 --- a/app/api/endpoints/gemini.py +++ b/app/api/endpoints/gemini.py @@ -32,15 +32,54 @@ def get_gemini_client(): _client = genai.Client(api_key=settings.GOOGLE_API_KEY, http_options={'api_version': 'v1alpha'}) return _client -@router.post("/chat") +@router.api_route("/chat", methods=["GET", "POST"]) @limiter.limit(settings.RATE_LIMIT) async def gemini_chat( request: Request, - chat_data: LLMRequest, api_key: str = Depends(get_api_key), module: Module = Depends(get_current_module), db: Session = Depends(get_db) ): + chat_data = None + + if request.method == "GET": + # Handle GET requests (Ultimate Simple Request) + params = request.query_params + if not params.get("prompt"): + return {"status": "error", "detail": "Missing 'prompt' query parameter"} + + chat_data = LLMRequest( + prompt=params.get("prompt"), + context=params.get("context", ""), + system_prompt=params.get("system_prompt"), + knowledge_base=params.get("knowledge_base"), + temperature=float(params.get("temperature", 0.7)), + top_p=float(params.get("top_p", 0.95)), + top_k=int(params.get("top_k", 40)), + max_output_tokens=int(params.get("max_output_tokens", 8192)) + ) + else: + # Handle POST requests + content_type = request.headers.get("Content-Type", "") + if "text/plain" in content_type: + try: + body = await request.body() + import json + data = json.loads(body) + chat_data = LLMRequest(**data) + except Exception as e: + return {"status": "error", "detail": f"Failed to parse text/plain as JSON: {str(e)}"} + else: + # Standard JSON parsing + try: + data = await request.json() + chat_data = LLMRequest(**data) + except Exception as e: + return {"status": "error", "detail": f"Invalid JSON: {str(e)}"} + + if not chat_data: + return {"status": "error", "detail": "Could not determine request data"} + client = get_gemini_client() try: diff --git a/app/main.py b/app/main.py index b62ddb6..97c9b8c 100644 --- a/app/main.py +++ b/app/main.py @@ -32,12 +32,21 @@ def create_application() -> FastAPI: application.mount("/static", StaticFiles(directory="app/static"), name="static") # Set up CORS + origins = [ + "https://articulateusercontent.com", + "https://ai-gateway.ldex.dev", + "http://localhost:8000", + "http://127.0.0.1:8000", + ] + application.add_middleware( CORSMiddleware, - allow_origins=["*"], - allow_credentials=False, # Changed to False for better compat with allow_origins=["*"] + allow_origins=origins, + allow_origin_regex=r"https://.*\.articulateusercontent\.com", + allow_credentials=True, allow_methods=["*"], allow_headers=["*"], + expose_headers=["*"], ) # Set up Rate Limiter