import json

from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
from pydantic import BaseModel
from google import genai
from google.genai import types

from app.api.deps import get_api_key, get_current_module
from app.models.module import Module
from app.core.database import get_db
from app.core.limiter import limiter
from app.core.config import settings
from app.core.prompts import GEMINI_SYSTEM_PROMPT

router = APIRouter()


class LLMRequest(BaseModel):
    """Request schema shared by the GET and POST entry points."""
    prompt: str
    context: str = ""
    system_prompt: str | None = None
    knowledge_base: str | None = None
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 40
    max_output_tokens: int = 8192


# Shared client instance (global, lazily initialized)
_client = None


def get_gemini_client():
    """Return a singleton Gemini client, or None when no real key is configured."""
    global _client
    if _client is None and settings.GOOGLE_API_KEY and settings.GOOGLE_API_KEY != "your-google-api-key":
        _client = genai.Client(
            api_key=settings.GOOGLE_API_KEY,
            http_options={"api_version": "v1alpha"},
        )
    return _client


@router.api_route("/chat", methods=["GET", "POST"])
@limiter.limit(settings.RATE_LIMIT)
async def gemini_chat(
    request: Request,
    api_key: str = Depends(get_api_key),
    module: Module = Depends(get_current_module),
    db: Session = Depends(get_db),
):
    chat_data = None
    if request.method == "GET":
        # Handle GET requests (simple query-string invocation)
        params = request.query_params
        if not params.get("prompt"):
            return {"status": "error", "detail": "Missing 'prompt' query parameter"}
        chat_data = LLMRequest(
            prompt=params.get("prompt"),
            context=params.get("context", ""),
            system_prompt=params.get("system_prompt"),
            knowledge_base=params.get("knowledge_base"),
            temperature=float(params.get("temperature", 0.7)),
            top_p=float(params.get("top_p", 0.95)),
            top_k=int(params.get("top_k", 40)),
            max_output_tokens=int(params.get("max_output_tokens", 8192)),
        )
    else:
        # Handle POST requests
        content_type = request.headers.get("Content-Type", "")
        if "text/plain" in content_type:
            # Some clients send JSON with a text/plain content type; parse it manually.
            try:
                body = await request.body()
                data = json.loads(body)
                chat_data = LLMRequest(**data)
            except Exception as e:
                return {"status": "error", "detail": f"Failed to parse text/plain as JSON: {str(e)}"}
        else:
            # Standard JSON parsing
            try:
                data = await request.json()
                chat_data = LLMRequest(**data)
            except Exception as e:
                return {"status": "error", "detail": f"Invalid JSON: {str(e)}"}

    if not chat_data:
        return {"status": "error", "detail": "Could not determine request data"}

    client = get_gemini_client()
    try:
        if not client:
            # No API key configured: return a mock response so callers can integrate offline.
            response_text = f"MOCK: Gemini response to '{chat_data.prompt}'"
            if module:
                # Rough token estimate for the mock path (~4 chars per token)
                prompt_tokens = len(chat_data.prompt) // 4
                completion_tokens = len(response_text) // 4
                module.ingress_tokens += prompt_tokens
                module.egress_tokens += completion_tokens
                module.total_tokens += (prompt_tokens + completion_tokens)
                db.commit()
            return {
                "status": "mock",
                "model": "gemini",
                "response": response_text,
            }

        prompt_content = chat_data.prompt
        if chat_data.context:
            prompt_content = f"Context: {chat_data.context}\n\nPrompt: {chat_data.prompt}"

        # Prepare the system instruction, appending any caller-supplied knowledge base
        system_instruction = chat_data.system_prompt or GEMINI_SYSTEM_PROMPT
        if chat_data.knowledge_base:
            system_instruction += f"\n\nKnowledge Base:\n{chat_data.knowledge_base}"

        # Use the async generation method from the google-genai library and await it
        # so we don't block the event loop.
        response = await client.aio.models.generate_content(
            model="gemini-2.5-flash",
            contents=prompt_content,
            config=types.GenerateContentConfig(
                system_instruction=system_instruction,
                temperature=chat_data.temperature,
                top_p=chat_data.top_p,
                top_k=chat_data.top_k,
                max_output_tokens=chat_data.max_output_tokens,
            ),
        )

        # Track usage against the calling module. Prefer exact counts from
        # usage_metadata; fall back to a rough estimate (~4 chars per token).
        if module:
            usage = response.usage_metadata
            prompt_tokens = usage.prompt_token_count if usage else len(prompt_content) // 4
            completion_tokens = usage.candidates_token_count if usage else len(response.text) // 4
            module.ingress_tokens += prompt_tokens
            module.egress_tokens += completion_tokens
            module.total_tokens += (prompt_tokens + completion_tokens)
            db.commit()
    except Exception as e:
        return {"status": "error", "detail": str(e)}

    # Final response
    return {
        "status": "success",
        "model": "gemini",
        "response": response.text,
    }
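

# --- Example usage (illustrative sketch, not part of the router) ---
# This shows one way to call the endpoint. The mount path "/llm" and the
# "X-API-Key" header name are assumptions, not defined in this file: adjust
# them to however your app includes this router and however get_api_key
# extracts the key.
#
#   import httpx
#
#   # POST with a JSON body (the standard path)
#   resp = httpx.post(
#       "http://localhost:8000/llm/chat",
#       headers={"X-API-Key": "<your-key>"},  # assumed header name
#       json={"prompt": "Say hello in French.", "temperature": 0.2},
#   )
#   print(resp.json())  # {"status": "success", "model": "gemini", "response": "..."}
#
#   # Or the GET form, passing everything as query parameters
#   resp = httpx.get(
#       "http://localhost:8000/llm/chat",
#       headers={"X-API-Key": "<your-key>"},
#       params={"prompt": "Say hello in French."},
#   )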