diff --git a/app/api/endpoints/gemini.py b/app/api/endpoints/gemini.py
index ad73666..da9cb38 100644
--- a/app/api/endpoints/gemini.py
+++ b/app/api/endpoints/gemini.py
@@ -54,11 +54,15 @@ async def gemini_chat(
             "response": response_text
         }
 
+    prompt_content = chat_data.prompt
+    if chat_data.context:
+        prompt_content = f"Context: {chat_data.context}\n\nPrompt: {chat_data.prompt}"
+
     # Using the async generation method provided by the new google-genai library
     # We use await to ensure we don't block the event loop
     response = await client.aio.models.generate_content(
         model="gemini-2.0-flash",
-        contents=chat_data.prompt
+        contents=prompt_content
     )
 
     # Track usage if valid module
@@ -67,7 +71,7 @@ async def gemini_chat(
         # 1 char ~= 0.25 tokens (rough estimate if exact count not returned)
         # Gemini response usually has usage_metadata
         usage = response.usage_metadata
-        prompt_tokens = usage.prompt_token_count if usage else len(chat_data.prompt) // 4
+        prompt_tokens = usage.prompt_token_count if usage else len(prompt_content) // 4
         completion_tokens = usage.candidates_token_count if usage else len(response.text) // 4
 
         module.ingress_tokens += prompt_tokens
diff --git a/app/api/endpoints/openai.py b/app/api/endpoints/openai.py
index 1f213af..2ecf28f 100644
--- a/app/api/endpoints/openai.py
+++ b/app/api/endpoints/openai.py
@@ -49,9 +49,14 @@ async def openai_chat(
         }
 
     # Perform Async call to OpenAI
+    messages = []
+    if chat_data.context:
+        messages.append({"role": "system", "content": chat_data.context})
+    messages.append({"role": "user", "content": chat_data.prompt})
+
     response = await client.chat.completions.create(
         model="gpt-3.5-turbo",
-        messages=[{"role": "user", "content": chat_data.prompt}]
+        messages=messages
     )
 
     # Track usage
@@ -61,7 +66,16 @@ async def openai_chat(
         module.ingress_tokens += usage.prompt_tokens
         module.egress_tokens += usage.completion_tokens
         module.total_tokens += usage.total_tokens
-        db.commit()
+    else:
+        # Fallback estimation
+        total_content = "".join([m["content"] for m in messages])
+        prompt_tokens = len(total_content) // 4
+        completion_tokens = len(response.choices[0].message.content) // 4
+        module.ingress_tokens += prompt_tokens
+        module.egress_tokens += completion_tokens
+        module.total_tokens += (prompt_tokens + completion_tokens)
+
+    db.commit()
 
     return {
         "status": "success",