# ai-gateway/app/api/endpoints/gemini.py

from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
from pydantic import BaseModel
from google import genai
from google.genai import types
import json

from app.api.deps import get_api_key, get_current_module
from app.models.module import Module
from app.core.database import get_db
from app.core.limiter import limiter
from app.core.config import settings
from app.core.prompts import GEMINI_SYSTEM_PROMPT

router = APIRouter()

class LLMRequest(BaseModel):
    prompt: str
    context: str = ""
    system_prompt: str | None = None
    knowledge_base: str | None = None
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 40
    max_output_tokens: int = 8192
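
# Example JSON body this model accepts (illustrative values; any omitted
# field falls back to the default declared above):
#   {
#     "prompt": "Summarize the attached notes",
#     "context": "meeting notes ...",
#     "temperature": 0.2,
#     "max_output_tokens": 1024
#   }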

# Shared client instance (module-level singleton, created on first use)
_client = None


def get_gemini_client():
    """Return a shared Gemini client, or None when no real API key is configured."""
    global _client
    if _client is None and settings.GOOGLE_API_KEY and settings.GOOGLE_API_KEY != "your-google-api-key":
        _client = genai.Client(api_key=settings.GOOGLE_API_KEY, http_options={'api_version': 'v1alpha'})
    return _client
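
# Sketch of a unit test that forces the mock path (hedged: assumes pytest's
# monkeypatch fixture and that this module imports as app.api.endpoints.gemini):
#
#   def test_mock_fallback(monkeypatch):
#       from app.api.endpoints import gemini
#       monkeypatch.setattr(gemini, "_client", None)
#       monkeypatch.setattr(gemini.settings, "GOOGLE_API_KEY", "")
#       assert gemini.get_gemini_client() is None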

@router.api_route("/chat", methods=["GET", "POST"])
@limiter.limit(settings.RATE_LIMIT)
async def gemini_chat(
    request: Request,
    api_key: str = Depends(get_api_key),
    module: Module = Depends(get_current_module),
    db: Session = Depends(get_db),
):
    chat_data = None
    if request.method == "GET":
        # Handle GET requests ("Ultimate Simple Request" path)
        params = request.query_params
        if not params.get("prompt"):
            return {"status": "error", "detail": "Missing 'prompt' query parameter"}
        try:
            chat_data = LLMRequest(
                prompt=params.get("prompt"),
                context=params.get("context", ""),
                system_prompt=params.get("system_prompt"),
                knowledge_base=params.get("knowledge_base"),
                temperature=float(params.get("temperature", 0.7)),
                top_p=float(params.get("top_p", 0.95)),
                top_k=int(params.get("top_k", 40)),
                max_output_tokens=int(params.get("max_output_tokens", 8192)),
            )
        except (TypeError, ValueError) as e:
            return {"status": "error", "detail": f"Invalid query parameter: {str(e)}"}
    else:
        # Handle POST requests
        content_type = request.headers.get("Content-Type", "")
        if "text/plain" in content_type:
            # Some clients send JSON with a text/plain content type; parse the raw body manually.
            try:
                body = await request.body()
                data = json.loads(body)
                chat_data = LLMRequest(**data)
            except Exception as e:
                return {"status": "error", "detail": f"Failed to parse text/plain as JSON: {str(e)}"}
        else:
            # Standard JSON parsing
            try:
                data = await request.json()
                chat_data = LLMRequest(**data)
            except Exception as e:
                return {"status": "error", "detail": f"Invalid JSON: {str(e)}"}

    if not chat_data:
        return {"status": "error", "detail": "Could not determine request data"}

    client = get_gemini_client()
    try:
        if not client:
            # No real API key configured: return a mock response so callers can develop offline.
            response_text = f"MOCK: Gemini response to '{chat_data.prompt}'"
            if module:
                # Rough token estimate for the mock (~4 characters per token)
                prompt_tokens = len(chat_data.prompt) // 4
                completion_tokens = len(response_text) // 4
                module.ingress_tokens += prompt_tokens
                module.egress_tokens += completion_tokens
                module.total_tokens += prompt_tokens + completion_tokens
                db.commit()
            return {
                "status": "mock",
                "model": "gemini",
                "response": response_text
            }

        prompt_content = chat_data.prompt
        if chat_data.context:
            prompt_content = f"Context: {chat_data.context}\n\nPrompt: {chat_data.prompt}"

        # Prepare the system instruction, appending any caller-supplied knowledge base
        system_instruction = chat_data.system_prompt or GEMINI_SYSTEM_PROMPT
        if chat_data.knowledge_base:
            system_instruction += f"\n\nKnowledge Base:\n{chat_data.knowledge_base}"

        # Await the async generation method from the google-genai library so the
        # request does not block the event loop.
        response = await client.aio.models.generate_content(
            model="gemini-2.5-flash",
            contents=prompt_content,
            config=types.GenerateContentConfig(
                system_instruction=system_instruction,
                temperature=chat_data.temperature,
                top_p=chat_data.top_p,
                top_k=chat_data.top_k,
                max_output_tokens=chat_data.max_output_tokens
            )
        )

        # Track usage for the calling module: prefer exact counts from
        # usage_metadata, falling back to a rough estimate (~4 chars per token).
        if module:
            usage = response.usage_metadata
            prompt_tokens = usage.prompt_token_count if usage else len(prompt_content) // 4
            completion_tokens = usage.candidates_token_count if usage else len(response.text) // 4
            module.ingress_tokens += prompt_tokens
            module.egress_tokens += completion_tokens
            module.total_tokens += prompt_tokens + completion_tokens
            db.commit()
    except Exception as e:
        return {"status": "error", "detail": str(e)}

    # Final response
    return {
        "status": "success",
        "model": "gemini",
        "response": response.text
    }
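

# A minimal manual smoke test, as a sketch: the route prefix and the auth
# header are assumptions here ("/chat" mounted at the app root, key sent as
# "X-API-Key"); adjust both to match how this router and get_api_key are
# actually wired.
if __name__ == "__main__":
    import httpx

    BASE_URL = "http://localhost:8000/chat"  # assumed mount point
    HEADERS = {"X-API-Key": "dev-key"}       # hypothetical header name and key

    # GET with query parameters
    result = httpx.get(BASE_URL, params={"prompt": "Hello", "temperature": 0.2}, headers=HEADERS)
    print(result.json())

    # POST with a JSON body
    result = httpx.post(BASE_URL, json={"prompt": "Hello", "context": "greeting"}, headers=HEADERS)
    print(result.json())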