# ai-gateway/app/api/endpoints/gemini_score.py
# Conversation scoring endpoint backed by Google Gemini.

from fastapi import APIRouter, Depends, Request, HTTPException
from app.api.deps import get_current_module
from app.models.module import Module
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.core.limiter import limiter
from app.core.config import settings
from pydantic import BaseModel
from typing import List
from google import genai
from google.genai import types
import json
# Router that the /score endpoint in this module is registered on.
router = APIRouter()
# Process-wide Gemini client cache; stays None until get_gemini_client()
# successfully builds a client.
_client = None
def get_gemini_client():
    """Return the shared Gemini client, building it on first use.

    The client is cached in the module-level ``_client``. No client is
    created when ``settings.GOOGLE_API_KEY`` is empty or still set to the
    placeholder value ``"your-google-api-key"``.

    Returns:
        The cached ``genai.Client``, or ``None`` when no usable API key is
        configured.
    """
    global _client

    # Fast path: a client was already built by an earlier call.
    if _client is not None:
        return _client

    api_key = settings.GOOGLE_API_KEY
    key_is_usable = bool(api_key) and api_key != "your-google-api-key"
    if key_is_usable:
        _client = genai.Client(
            api_key=api_key,
            http_options={"api_version": "v1alpha"},
        )
    return _client
class TranscriptEntry(BaseModel):
    """One turn of the conversation transcript submitted for scoring."""

    # Speaker label. Free-form string (not validated here); it is upper-cased
    # when rendered into the prompt.
    role: str  # "user" or "assistant"
    # What the speaker said in this turn.
    text: str
class ScorecardCriterion(BaseModel):
    """A single criterion the transcript is evaluated against."""

    # Criterion label; echoed back as the name of each mock criterion score.
    name: str
    # Rendered into the prompt as "weight: <value>%". The prompt states that
    # weights must sum to 100%, but nothing in this model enforces it.
    weight: float
    # Explanation of what the criterion measures, passed to the model verbatim.
    description: str
    # Optional illustrative examples; each is added to the prompt only when
    # it is non-empty.
    good_example: str | None = None
    poor_example: str | None = None
class Scorecard(BaseModel):
    """The set of criteria a conversation is scored against."""

    # Criteria are rendered into the prompt in list order.
    criteria: List[ScorecardCriterion]
class ScoreRequest(BaseModel):
    """Request body for the /score endpoint."""

    # Conversation turns, rendered into the prompt in list order.
    transcript: List[TranscriptEntry]
    # Criteria and weights to score the transcript against.
    scorecard: Scorecard
    # Overall score at or above which the prompt tells the model to set
    # ``passed`` to true.
    pass_threshold: int = 70
# Structured-output schema passed to Gemini as ``response_schema``.
# Properties are only guaranteed to be present when listed under "required",
# so both the top-level object and each criteria_scores item declare their
# mandatory fields explicitly.
SCORE_RESPONSE_SCHEMA = {
    "type": "OBJECT",
    "properties": {
        "overall_score": {"type": "NUMBER"},
        "passed": {"type": "BOOLEAN"},
        "criteria_scores": {
            "type": "ARRAY",
            "items": {
                "type": "OBJECT",
                "properties": {
                    "name": {"type": "STRING"},
                    "score": {"type": "NUMBER"},
                    "feedback": {"type": "STRING"},
                },
                # Without this the model may omit any of the per-criterion
                # fields, leaving consumers with partial entries.
                "required": ["name", "score", "feedback"],
            },
        },
        "positives": {"type": "ARRAY", "items": {"type": "STRING"}},
        "improvements": {"type": "ARRAY", "items": {"type": "STRING"}},
    },
    "required": ["overall_score", "passed", "criteria_scores", "positives", "improvements"],
}
def _build_prompt(body: ScoreRequest) -> str:
transcript_text = "\n".join(
f"{e.role.upper()}: {e.text}" for e in body.transcript
)
criteria_text = "\n".join(
f"- {c.name} (weight: {c.weight}%): {c.description}"
+ (f"\n Good example: {c.good_example}" if c.good_example else "")
+ (f"\n Poor example: {c.poor_example}" if c.poor_example else "")
for c in body.scorecard.criteria
)
return f"""You are an expert conversation evaluator for workplace learning simulations.
Score the following conversation transcript against the provided scorecard criteria.
TRANSCRIPT:
{transcript_text}
SCORING CRITERIA (weights must sum to 100%):
{criteria_text}
Instructions:
- Score each criterion from 0 to 100 based on evidence in the transcript.
- Calculate overall_score as the weighted average of all criteria scores.
- Set passed to true if overall_score >= {body.pass_threshold}.
- Write specific, evidence-based feedback for each criterion (1-2 sentences).
- List exactly 2-3 positives (specific things done well with transcript evidence).
- List exactly 2-3 improvements (specific, actionable suggestions).
Return a single JSON object following the response schema."""
def _mock_score_result(body: ScoreRequest) -> dict:
    """Build the placeholder scorecard returned when no Gemini client exists.

    Every criterion receives the same fixed score. ``passed`` is derived from
    ``body.pass_threshold`` so the placeholder never claims a pass that its
    own score does not meet.
    """
    mock_score = 75.0
    return {
        "overall_score": mock_score,
        "passed": mock_score >= body.pass_threshold,
        "criteria_scores": [
            {
                "name": c.name,
                "score": mock_score,
                "feedback": f"Mock feedback for {c.name}.",
            }
            for c in body.scorecard.criteria
        ],
        "positives": [
            "Maintained a professional tone throughout.",
            "Responded clearly to the main questions.",
        ],
        "improvements": [
            "Could provide more specific examples.",
            "Consider addressing the customer's emotional state earlier.",
        ],
    }


@router.post("/score")
@limiter.limit(settings.RATE_LIMIT)
async def score_conversation(
    request: Request,
    body: ScoreRequest,
    module: Module = Depends(get_current_module),
    db: Session = Depends(get_db),
):
    """Score a conversation transcript against a scorecard using Gemini.

    When no Gemini client is configured, a fixed mock scorecard is returned
    instead of calling the model.

    Args:
        request: Incoming request. Not read in this function; presumably
            required by the rate-limit decorator (confirm against the
            limiter's documentation).
        body: Transcript, scorecard and pass threshold to evaluate.
        module: Calling module; when truthy, its token counters are
            incremented with this call's usage.
        db: Database session used to commit the token counters.

    Returns:
        The parsed JSON object produced by the model (shaped by
        ``SCORE_RESPONSE_SCHEMA``), or the mock scorecard.

    Raises:
        HTTPException: 502 when Gemini returns an empty or non-JSON body;
            500 for any other failure during the call or token accounting.
    """
    client = get_gemini_client()
    if not client:
        return _mock_score_result(body)

    prompt = _build_prompt(body)
    try:
        response = await client.aio.models.generate_content(
            model="gemini-2.5-flash",
            contents=prompt,
            config=types.GenerateContentConfig(
                response_mime_type="application/json",
                response_schema=SCORE_RESPONSE_SCHEMA,
                temperature=0.2,
            ),
        )

        # response.text is None/empty when the model produced no text (for
        # example a blocked response); report that as an upstream failure
        # instead of letting json.loads raise a TypeError.
        raw_text = response.text
        if not raw_text:
            raise HTTPException(status_code=502, detail="Gemini returned an empty response")
        result = json.loads(raw_text)

        if module:
            usage = response.usage_metadata
            prompt_tokens = usage.prompt_token_count if usage else None
            completion_tokens = usage.candidates_token_count if usage else None
            # Individual counts can be missing even when usage metadata is
            # present; fall back to a rough length-based estimate, the same
            # fallback used when the metadata is absent altogether.
            if prompt_tokens is None:
                prompt_tokens = len(prompt) // 4
            if completion_tokens is None:
                completion_tokens = len(raw_text) // 4
            module.ingress_tokens += prompt_tokens
            module.egress_tokens += completion_tokens
            module.total_tokens += prompt_tokens + completion_tokens
            db.commit()
        return result
    except HTTPException:
        # Already a deliberate HTTP error (raised above); do not rewrap as 500.
        raise
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=502, detail=f"Gemini returned invalid JSON: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e