# ai-gateway/app/api/endpoints/gemini.py

import json

from fastapi import APIRouter, Depends, Request
from app.api.deps import get_api_key, get_current_module
from app.models.module import Module
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.core.limiter import limiter
from app.core.config import settings
from pydantic import BaseModel
from google import genai
from google.genai import types
from app.core.prompts import GEMINI_SYSTEM_PROMPT

router = APIRouter()

class LLMRequest(BaseModel):
    prompt: str
    context: str = ""
    system_prompt: str | None = None
    knowledge_base: str | None = None
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 40
    max_output_tokens: int = 8192
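
# Example request body (illustrative; only "prompt" is required, everything
# else falls back to the defaults declared above):
#   {"prompt": "Summarize this text", "context": "pasted article", "temperature": 0.2}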

# Shared Gemini client instance (global); stays None when no real API key is
# configured, in which case the endpoint returns a mock response instead.
_client = None


def get_gemini_client():
    global _client
    if _client is None and settings.GOOGLE_API_KEY and settings.GOOGLE_API_KEY != "your-google-api-key":
        _client = genai.Client(api_key=settings.GOOGLE_API_KEY, http_options={'api_version': 'v1alpha'})
    return _client

@router.api_route("/chat", methods=["GET", "POST"])
@limiter.limit(settings.RATE_LIMIT)
async def gemini_chat(
    request: Request,
    api_key: str = Depends(get_api_key),
    module: Module = Depends(get_current_module),
    db: Session = Depends(get_db)
):
    chat_data = None
    if request.method == "GET":
        # Handle GET requests (simplest possible call: everything in the query string)
        params = request.query_params
        if not params.get("prompt"):
            return {"status": "error", "detail": "Missing 'prompt' query parameter"}
        try:
            chat_data = LLMRequest(
                prompt=params.get("prompt"),
                context=params.get("context", ""),
                system_prompt=params.get("system_prompt"),
                knowledge_base=params.get("knowledge_base"),
                temperature=float(params.get("temperature", 0.7)),
                top_p=float(params.get("top_p", 0.95)),
                top_k=int(params.get("top_k", 40)),
                max_output_tokens=int(params.get("max_output_tokens", 8192))
            )
        except ValueError as e:
            return {"status": "error", "detail": f"Invalid query parameter: {str(e)}"}
    else:
        # Handle POST requests
        content_type = request.headers.get("Content-Type", "")
        if "text/plain" in content_type:
            # Some clients send JSON bodies with a text/plain content type; try to parse anyway
            try:
                body = await request.body()
                data = json.loads(body)
                chat_data = LLMRequest(**data)
            except Exception as e:
                return {"status": "error", "detail": f"Failed to parse text/plain as JSON: {str(e)}"}
        else:
            # Standard JSON parsing
            try:
                data = await request.json()
                chat_data = LLMRequest(**data)
            except Exception as e:
                return {"status": "error", "detail": f"Invalid JSON: {str(e)}"}

    if not chat_data:
        return {"status": "error", "detail": "Could not determine request data"}

    client = get_gemini_client()
    try:
        if not client:
            # No real API key configured: return a mock response
            response_text = f"MOCK: Gemini response to '{chat_data.prompt}'"
            if module:
                # Estimate tokens for the mock (rough heuristic: ~4 characters per token)
                prompt_tokens = len(chat_data.prompt) // 4
                completion_tokens = len(response_text) // 4
                module.ingress_tokens += prompt_tokens
                module.egress_tokens += completion_tokens
                module.total_tokens += (prompt_tokens + completion_tokens)
                db.commit()
            return {
                "status": "mock",
                "model": "gemini",
                "response": response_text
            }

        prompt_content = chat_data.prompt
        if chat_data.context:
            prompt_content = f"Context: {chat_data.context}\n\nPrompt: {chat_data.prompt}"

        # Prepare the system instruction
        system_instruction = chat_data.system_prompt or GEMINI_SYSTEM_PROMPT
        if chat_data.knowledge_base:
            system_instruction += f"\n\nKnowledge Base:\n{chat_data.knowledge_base}"

        # Use the async generation method from the google-genai library and await
        # it so we don't block the event loop
        response = await client.aio.models.generate_content(
            model="gemini-2.5-flash",
            contents=prompt_content,
            config=types.GenerateContentConfig(
                system_instruction=system_instruction,
                temperature=chat_data.temperature,
                top_p=chat_data.top_p,
                top_k=chat_data.top_k,
                max_output_tokens=chat_data.max_output_tokens
            )
        )
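
        # If incremental output were ever needed, the google-genai client also has a
        # streaming variant; a rough sketch (not used here, shown only for reference):
        #
        #   async for chunk in await client.aio.models.generate_content_stream(
        #       model="gemini-2.5-flash",
        #       contents=prompt_content,
        #       config=config,
        #   ):
        #       ...  # forward chunk.text to the caller, e.g. via a StreamingResponse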

        # Track usage if the request is tied to a valid module. Gemini responses
        # usually carry usage_metadata with exact counts; fall back to a rough
        # estimate (~4 characters per token) if it is missing.
        if module:
            usage = response.usage_metadata
            prompt_tokens = usage.prompt_token_count if usage and usage.prompt_token_count else len(prompt_content) // 4
            completion_tokens = usage.candidates_token_count if usage and usage.candidates_token_count else len(response.text or "") // 4
            module.ingress_tokens += prompt_tokens
            module.egress_tokens += completion_tokens
            module.total_tokens += (prompt_tokens + completion_tokens)
            db.commit()

        return {
            "status": "success",
            "model": "gemini",
            "response": response.text
        }
    except Exception as e:
        return {"status": "error", "detail": str(e)}