from fastapi import APIRouter, Depends, Request
from app.api.deps import get_api_key, get_current_module
from app.models.module import Module
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.core.limiter import limiter
from app.core.config import settings
from pydantic import BaseModel
from google import genai
from google.genai import types
from app.core.prompts import GEMINI_SYSTEM_PROMPT
import json

router = APIRouter()


class LLMRequest(BaseModel):
    prompt: str
    context: str = ""
    system_prompt: str | None = None
    knowledge_base: str | None = None
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 40
    max_output_tokens: int = 8192


# Shared Gemini client instance, created lazily on first use (module-level global)
_client = None


def get_gemini_client():
    global _client
    # Only create a real client when a usable API key is configured;
    # otherwise return None and let callers fall back to a mock response.
    if _client is None and settings.GOOGLE_API_KEY and settings.GOOGLE_API_KEY != "your-google-api-key":
        _client = genai.Client(api_key=settings.GOOGLE_API_KEY, http_options={'api_version': 'v1alpha'})
    return _client


@router.api_route("/chat", methods=["GET", "POST"])
@limiter.limit(settings.RATE_LIMIT)
async def gemini_chat(
    request: Request,
    api_key: str = Depends(get_api_key),
    module: Module = Depends(get_current_module),
    db: Session = Depends(get_db)
):
    chat_data = None
    if request.method == "GET":
        # Handle GET requests (simple query-string interface)
        params = request.query_params
        if not params.get("prompt"):
            return {"status": "error", "detail": "Missing 'prompt' query parameter"}
        chat_data = LLMRequest(
            prompt=params.get("prompt"),
            context=params.get("context", ""),
            system_prompt=params.get("system_prompt"),
            knowledge_base=params.get("knowledge_base"),
            temperature=float(params.get("temperature", 0.7)),
            top_p=float(params.get("top_p", 0.95)),
            top_k=int(params.get("top_k", 40)),
            max_output_tokens=int(params.get("max_output_tokens", 8192))
        )
    else:
        # Handle POST requests
        content_type = request.headers.get("Content-Type", "")
        if "text/plain" in content_type:
            # Some clients send JSON bodies with a text/plain content type;
            # parse the raw body manually in that case
            try:
                body = await request.body()
                data = json.loads(body)
                chat_data = LLMRequest(**data)
            except Exception as e:
                return {"status": "error", "detail": f"Failed to parse text/plain as JSON: {str(e)}"}
        else:
            # Standard JSON parsing
            try:
                data = await request.json()
                chat_data = LLMRequest(**data)
            except Exception as e:
                return {"status": "error", "detail": f"Invalid JSON: {str(e)}"}

    if not chat_data:
        return {"status": "error", "detail": "Could not determine request data"}

    client = get_gemini_client()
    try:
        if not client:
            # No API key configured: return a mock response so the endpoint
            # stays usable in local development
            response_text = f"MOCK: Gemini response to '{chat_data.prompt}'"
            if module:
                # Estimate tokens for the mock path (~4 characters per token)
                prompt_tokens = len(chat_data.prompt) // 4
                completion_tokens = len(response_text) // 4
                module.ingress_tokens += prompt_tokens
                module.egress_tokens += completion_tokens
                module.total_tokens += (prompt_tokens + completion_tokens)
                db.commit()
            return {
                "status": "mock",
                "model": "gemini",
                "response": response_text
            }

        prompt_content = chat_data.prompt
        if chat_data.context:
            prompt_content = f"Context: {chat_data.context}\n\nPrompt: {chat_data.prompt}"

        # Prepare the system instruction, appending any caller-supplied knowledge base
        system_instruction = chat_data.system_prompt or GEMINI_SYSTEM_PROMPT
        if chat_data.knowledge_base:
            system_instruction += f"\n\nKnowledge Base:\n{chat_data.knowledge_base}"

        # Use the async generation method from the google-genai library,
        # awaited so we don't block the event loop
        response = await client.aio.models.generate_content(
            model="gemini-2.5-flash",
            contents=prompt_content,
            config=types.GenerateContentConfig(
                system_instruction=system_instruction,
                temperature=chat_data.temperature,
                top_p=chat_data.top_p,
                top_k=chat_data.top_k,
                max_output_tokens=chat_data.max_output_tokens
            )
        )

        # Track usage if a valid module is attached to the request
        if module:
            # Prefer exact counts from the response's usage_metadata; fall back
            # to a rough estimate of ~4 characters per token if they are missing
            usage = response.usage_metadata
            prompt_tokens = usage.prompt_token_count if usage and usage.prompt_token_count else len(prompt_content) // 4
            completion_tokens = usage.candidates_token_count if usage and usage.candidates_token_count else len(response.text or "") // 4
            module.ingress_tokens += prompt_tokens
            module.egress_tokens += completion_tokens
            module.total_tokens += (prompt_tokens + completion_tokens)
            db.commit()

        return {
            "status": "success",
            "model": "gemini",
            "response": response.text
        }
    except Exception as e:
        return {"status": "error", "detail": str(e)}
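

# --- Example usage (sketch) -------------------------------------------------
# A minimal client-side call, assuming the router is mounted under
# /api/v1/gemini and that get_api_key reads an "X-API-Key" header; both the
# path prefix and the header name are assumptions, so adjust them to how this
# router is actually included and authenticated in your app.
#
#   import httpx
#
#   resp = httpx.post(
#       "http://localhost:8000/api/v1/gemini/chat",
#       headers={"X-API-Key": "my-key"},
#       json={"prompt": "Summarize FastAPI in one line", "temperature": 0.2},
#   )
#   print(resp.json())  # e.g. {"status": "success", "model": "gemini", "response": "..."}
#
# The same endpoint also accepts a GET shorthand, e.g.
#   GET /api/v1/gemini/chat?prompt=hello&temperature=0.2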