from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel
from sqlalchemy.orm import Session
from openai import AsyncOpenAI

from app.api.deps import get_api_key, get_current_module
from app.core.config import settings
from app.core.database import get_db
from app.core.limiter import limiter
from app.models.module import Module

router = APIRouter()


class LLMRequest(BaseModel):
    prompt: str
    context: str = ""
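
# Example request body this model validates (values are illustrative only):
#   {"prompt": "Summarize this ticket", "context": "You are a support assistant."}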

# Initialize the async client; when no real key is configured, the endpoint
# falls back to a mock response instead of calling OpenAI.
client = None
if settings.OPENAI_API_KEY and settings.OPENAI_API_KEY != "your-openai-api-key":
    client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)


@router.post("/chat")
@limiter.limit(settings.RATE_LIMIT)
async def openai_chat(
    request: Request,
    chat_data: LLMRequest,
    api_key: str = Depends(get_api_key),
    module: Module = Depends(get_current_module),
    db: Session = Depends(get_db),
):
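    """Proxy a chat prompt to OpenAI (or a mock) and record the module's token usage."""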
    try:
        if not client:
            # Mock response used when no real API key is configured
            response_text = f"MOCK: OpenAI response to '{chat_data.prompt}'"
            if module:
                # Estimate tokens for the mock (~4 characters per token)
                prompt_tokens = len(chat_data.prompt) // 4
                completion_tokens = len(response_text) // 4
                module.ingress_tokens += prompt_tokens
                module.egress_tokens += completion_tokens
                module.total_tokens += prompt_tokens + completion_tokens
                db.commit()

            return {
                "status": "mock",
                "model": "openai",
                "response": response_text,
            }
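
        # (The // 4 estimates above assume roughly 4 characters per English
        # token; e.g. a 120-character prompt is booked as 120 // 4 = 30 tokens.)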

        # Build the message list and perform the async call to OpenAI
        messages = []
        if chat_data.context:
            messages.append({"role": "system", "content": chat_data.context})
        messages.append({"role": "user", "content": chat_data.prompt})

        response = await client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=messages,
        )

        # Track usage: prefer the exact counts reported by the API
        if module:
            usage = response.usage
            if usage:
                module.ingress_tokens += usage.prompt_tokens
                module.egress_tokens += usage.completion_tokens
                module.total_tokens += usage.total_tokens
            else:
                # Fallback estimation (~4 characters per token); guard against
                # a None message content before taking its length
                total_content = "".join(m["content"] for m in messages)
                prompt_tokens = len(total_content) // 4
                completion_tokens = len(response.choices[0].message.content or "") // 4
                module.ingress_tokens += prompt_tokens
                module.egress_tokens += completion_tokens
                module.total_tokens += prompt_tokens + completion_tokens

            db.commit()

        return {
            "status": "success",
            "model": "openai",
            "response": response.choices[0].message.content,
        }
    except Exception as e:
        # Return the error in-body (HTTP 200) rather than raising, so callers
        # always receive a JSON envelope with a "status" field
        return {"status": "error", "detail": str(e)}
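

# Usage sketch (assumptions: this router is mounted under an "/llm" prefix and
# get_api_key reads an "X-API-Key" header; both depend on how app.api.deps and
# the main app are wired, so adjust to the actual setup):
#
#   import httpx
#
#   resp = httpx.post(
#       "http://localhost:8000/llm/chat",
#       headers={"X-API-Key": "<your-service-key>"},
#       json={"prompt": "Hello!", "context": "Answer in one sentence."},
#   )
#   print(resp.json())  # {"status": "success", "model": "openai", "response": ...}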