from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
from pydantic import BaseModel
from google import genai
import asyncio

from app.api.deps import get_api_key, get_current_module
from app.models.module import Module
from app.core.database import get_db
from app.core.limiter import limiter
from app.core.config import settings

router = APIRouter()


class LLMRequest(BaseModel):
    """Request payload for the Gemini chat endpoint."""

    prompt: str
    context: str = ""  # optional extra context; not currently used by the handler


# Shared client instance (module-level singleton, created lazily).
_client = None


def get_gemini_client():
    """Return the shared genai.Client, or None when no real API key is set.

    A None return signals the route to fall back to a mock response, which
    keeps local development working without credentials. The comparison
    against "your-google-api-key" filters out the sample/placeholder value.
    """
    global _client
    if (
        _client is None
        and settings.GOOGLE_API_KEY
        and settings.GOOGLE_API_KEY != "your-google-api-key"
    ):
        _client = genai.Client(
            api_key=settings.GOOGLE_API_KEY,
            http_options={'api_version': 'v1alpha'},
        )
    return _client


def _estimate_tokens(text: str) -> int:
    """Rough token estimate: ~4 characters per token."""
    return len(text) // 4


def _track_usage(db: Session, module: Module, prompt_tokens: int, completion_tokens: int) -> None:
    """Accumulate token counts on the caller's module record and persist them."""
    module.ingress_tokens += prompt_tokens
    module.egress_tokens += completion_tokens
    module.total_tokens += (prompt_tokens + completion_tokens)
    db.commit()


@router.post("/chat")
@limiter.limit(settings.RATE_LIMIT)
async def gemini_chat(
    request: Request,
    chat_data: LLMRequest,
    api_key: str = Depends(get_api_key),
    module: Module = Depends(get_current_module),
    db: Session = Depends(get_db)
):
    """Proxy a prompt to Gemini and record token usage for the calling module.

    Returns a dict with status "success", "mock" (no API key configured), or
    "error" (exception text under "detail"). Errors are reported in-band
    rather than via HTTP status codes — existing clients rely on this shape.
    """
    client = get_gemini_client()
    try:
        if not client:
            # No real API key configured: return a canned response so local
            # development works without credentials.
            response_text = f"MOCK: Gemini response to '{chat_data.prompt}'"
            if module:
                _track_usage(
                    db,
                    module,
                    _estimate_tokens(chat_data.prompt),
                    _estimate_tokens(response_text),
                )
            return {
                "status": "mock",
                "model": "gemini",
                "response": response_text
            }

        # Async generation via the google-genai library so the event loop
        # is not blocked while waiting on the API.
        response = await client.aio.models.generate_content(
            model="gemini-2.0-flash",
            contents=chat_data.prompt
        )

        if module:
            # Prefer exact counts from usage_metadata; fall back to the
            # rough chars/4 estimate when the API omits them.
            usage = response.usage_metadata
            prompt_tokens = (
                usage.prompt_token_count if usage
                else _estimate_tokens(chat_data.prompt)
            )
            completion_tokens = (
                usage.candidates_token_count if usage
                else _estimate_tokens(response.text)
            )
            _track_usage(db, module, prompt_tokens, completion_tokens)

        return {
            "status": "success",
            "model": "gemini",
            "response": response.text
        }
    except Exception as e:
        # Broad catch preserves the endpoint's in-band error contract:
        # surface the message to the caller instead of a raw 500.
        return {"status": "error", "detail": str(e)}