85 lines
2.9 KiB
Python
85 lines
2.9 KiB
Python
from fastapi import APIRouter, Depends, Request
|
|
from app.api.deps import get_api_key, get_current_module
|
|
from app.models.module import Module
|
|
from sqlalchemy.orm import Session
|
|
from app.core.database import get_db
|
|
from app.core.limiter import limiter
|
|
from app.core.config import settings
|
|
from pydantic import BaseModel
|
|
from google import genai
|
|
import asyncio
|
|
|
|
# FastAPI router for the Gemini chat endpoint defined below.
router = APIRouter()
|
|
|
|
class LLMRequest(BaseModel):
    """Request body for the /chat endpoint."""

    # Prompt text forwarded to the Gemini model.
    prompt: str
    # Optional extra context; accepted but currently unused by the endpoint.
    context: str = ""
|
|
|
|
# Lazily-created Gemini client shared across all requests in this module.
_client = None


def get_gemini_client():
    """Return the shared Gemini client, creating it on first use.

    Returns None when no usable API key is configured; callers treat
    that as a signal to serve a mock response instead.
    """
    global _client
    if _client is not None:
        return _client
    key = settings.GOOGLE_API_KEY
    # The template placeholder value does not count as a configured key.
    if key and key != "your-google-api-key":
        _client = genai.Client(api_key=key, http_options={'api_version': 'v1alpha'})
    return _client
|
|
|
|
def _track_usage(db: Session, module: Module, prompt_tokens: int, completion_tokens: int) -> None:
    """Accumulate token counts on the module record and persist them.

    Shared by the mock and real response paths so the accounting logic
    lives in exactly one place.
    """
    module.ingress_tokens += prompt_tokens
    module.egress_tokens += completion_tokens
    module.total_tokens += (prompt_tokens + completion_tokens)
    db.commit()


@router.post("/chat")
@limiter.limit(settings.RATE_LIMIT)
async def gemini_chat(
    request: Request,
    chat_data: LLMRequest,
    api_key: str = Depends(get_api_key),
    module: Module = Depends(get_current_module),
    db: Session = Depends(get_db)
):
    """Proxy a chat prompt to Gemini and record token usage per module.

    Falls back to a canned mock response when no Gemini client is
    configured (see get_gemini_client). Errors are reported in-band as
    a 200 response with status="error" rather than an HTTP error.
    """
    client = get_gemini_client()

    try:
        if not client:
            # No API key configured: serve a deterministic mock response.
            response_text = f"MOCK: Gemini response to '{chat_data.prompt}'"
            if module:
                # Rough estimate: ~4 characters per token.
                _track_usage(
                    db,
                    module,
                    len(chat_data.prompt) // 4,
                    len(response_text) // 4,
                )
            return {
                "status": "mock",
                "model": "gemini",
                "response": response_text
            }

        # Async generation via the google-genai client so the event loop
        # is not blocked while Gemini responds.
        response = await client.aio.models.generate_content(
            model="gemini-2.0-flash",
            contents=chat_data.prompt
        )

        # Track usage if the request is tied to a valid module.
        if module:
            # Prefer exact counts from usage_metadata; fall back to the
            # ~4 chars/token estimate when the API omits them.
            usage = response.usage_metadata
            prompt_tokens = usage.prompt_token_count if usage else len(chat_data.prompt) // 4
            completion_tokens = usage.candidates_token_count if usage else len(response.text) // 4
            _track_usage(db, module, prompt_tokens, completion_tokens)

        return {
            "status": "success",
            "model": "gemini",
            "response": response.text
        }
    except Exception as e:
        # NOTE(review): the broad catch keeps the endpoint from 500ing but
        # also hides programming errors — consider narrowing or at least
        # logging before returning the in-band error payload.
        return {"status": "error", "detail": str(e)}
|