use crate::models::error::{AgentError, BadRequestErrorMessage};
use crate::models::llm::{
    GenerationDelta, GenerationDeltaToolUse, LLMChoice, LLMCompletionResponse, LLMMessage,
    LLMMessageContent, LLMMessageTypedContent, LLMTokenUsage, LLMTool,
};
use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
use futures_util::StreamExt;
use reqwest_middleware::ClientBuilder;
use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use uuid::Uuid;

const DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct GeminiConfig {
    pub api_endpoint: Option<String>,
    pub api_key: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub enum GeminiModel {
    #[default]
    #[serde(rename = "gemini-3-pro-preview")]
    Gemini3Pro,
    #[serde(rename = "gemini-3-flash-preview")]
    Gemini3Flash,
    #[serde(rename = "gemini-2.5-pro")]
    Gemini25Pro,
    #[serde(rename = "gemini-2.5-flash")]
    Gemini25Flash,
    #[serde(rename = "gemini-2.5-flash-lite")]
    Gemini25FlashLite,
}

impl std::fmt::Display for GeminiModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            GeminiModel::Gemini3Pro => write!(f, "gemini-3-pro-preview"),
            GeminiModel::Gemini3Flash => write!(f, "gemini-3-flash-preview"),
            GeminiModel::Gemini25Pro => write!(f, "gemini-2.5-pro"),
            GeminiModel::Gemini25Flash => write!(f, "gemini-2.5-flash"),
            GeminiModel::Gemini25FlashLite => write!(f, "gemini-2.5-flash-lite"),
        }
    }
}

impl GeminiModel {
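    /// Parses a model from its wire identifier (the `#[serde(rename)]` values
    /// above). A minimal sketch of the round trip:
    ///
    /// ```ignore
    /// let model = GeminiModel::from_string("gemini-2.5-flash")?;
    /// assert_eq!(model.to_string(), "gemini-2.5-flash");
    /// ```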
    pub fn from_string(s: &str) -> Result<Self, String> {
        serde_json::from_value(serde_json::Value::String(s.to_string()))
            .map_err(|_| "Failed to deserialize Gemini model".to_string())
    }

    pub const DEFAULT_SMART_MODEL: GeminiModel = GeminiModel::Gemini3Pro;

    pub const DEFAULT_ECO_MODEL: GeminiModel = GeminiModel::Gemini3Flash;

    pub const DEFAULT_RECOVERY_MODEL: GeminiModel = GeminiModel::Gemini3Flash;

    pub fn default_smart_model() -> String {
        Self::DEFAULT_SMART_MODEL.to_string()
    }

    pub fn default_eco_model() -> String {
        Self::DEFAULT_ECO_MODEL.to_string()
    }

    pub fn default_recovery_model() -> String {
        Self::DEFAULT_RECOVERY_MODEL.to_string()
    }
}

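// The `context_info()` tiers below are ordered from the smallest `upper_bound`
// to a final open-ended tier (`upper_bound: None`). A sketch of how a caller
// might pick the tier for a given prompt size; `tier_for` is illustrative and
// not an API defined in this crate:
//
//     fn tier_for(info: &ModelContextInfo, prompt_tokens: u32) -> &ContextPricingTier {
//         info.pricing_tiers
//             .iter()
//             .find(|t| t.upper_bound.map_or(true, |bound| prompt_tokens <= bound))
//             .expect("every model defines at least one tier")
//     }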
impl ContextAware for GeminiModel {
    fn context_info(&self) -> ModelContextInfo {
        match self {
            GeminiModel::Gemini3Pro => ModelContextInfo {
                max_tokens: 1_000_000,
                pricing_tiers: vec![
                    ContextPricingTier {
                        label: "<200k tokens".to_string(),
                        input_cost_per_million: 2.0,
                        output_cost_per_million: 12.0,
                        upper_bound: Some(200_000),
                    },
                    ContextPricingTier {
                        label: ">200k tokens".to_string(),
                        input_cost_per_million: 4.0,
                        output_cost_per_million: 18.0,
                        upper_bound: None,
                    },
                ],
                approach_warning_threshold: 0.8,
            },
            GeminiModel::Gemini25Pro => ModelContextInfo {
                max_tokens: 1_000_000,
                pricing_tiers: vec![
                    ContextPricingTier {
                        label: "<200k tokens".to_string(),
                        input_cost_per_million: 1.25,
                        output_cost_per_million: 10.0,
                        upper_bound: Some(200_000),
                    },
                    ContextPricingTier {
                        label: ">200k tokens".to_string(),
                        input_cost_per_million: 2.50,
                        output_cost_per_million: 15.0,
                        upper_bound: None,
                    },
                ],
                approach_warning_threshold: 0.8,
            },
            GeminiModel::Gemini25Flash => ModelContextInfo {
                max_tokens: 1_000_000,
                pricing_tiers: vec![ContextPricingTier {
                    label: "Standard".to_string(),
                    input_cost_per_million: 0.30,
                    output_cost_per_million: 2.50,
                    upper_bound: None,
                }],
                approach_warning_threshold: 0.8,
            },
            GeminiModel::Gemini3Flash => ModelContextInfo {
                max_tokens: 1_000_000,
                pricing_tiers: vec![ContextPricingTier {
                    label: "Standard".to_string(),
                    input_cost_per_million: 0.50,
                    output_cost_per_million: 3.0,
                    upper_bound: None,
                }],
                approach_warning_threshold: 0.8,
            },
            GeminiModel::Gemini25FlashLite => ModelContextInfo {
                max_tokens: 1_000_000,
                pricing_tiers: vec![ContextPricingTier {
                    label: "Standard".to_string(),
                    input_cost_per_million: 0.1,
                    output_cost_per_million: 0.4,
                    upper_bound: None,
                }],
                approach_warning_threshold: 0.8,
            },
        }
    }

    fn model_name(&self) -> String {
        match self {
            GeminiModel::Gemini3Pro => "Gemini 3 Pro".to_string(),
            GeminiModel::Gemini3Flash => "Gemini 3 Flash".to_string(),
            GeminiModel::Gemini25Pro => "Gemini 2.5 Pro".to_string(),
            GeminiModel::Gemini25Flash => "Gemini 2.5 Flash".to_string(),
            GeminiModel::Gemini25FlashLite => "Gemini 2.5 Flash Lite".to_string(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug)]
pub struct GeminiInput {
    pub model: GeminiModel,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<LLMTool>>,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct GeminiRequest {
    pub contents: Vec<GeminiContent>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<GeminiTool>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_instruction: Option<GeminiSystemInstruction>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_config: Option<GeminiGenerationConfig>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum GeminiRole {
    #[serde(rename = "user")]
    User,
    #[serde(rename = "model")]
    Model,
}

impl std::fmt::Display for GeminiRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            GeminiRole::User => write!(f, "user"),
            GeminiRole::Model => write!(f, "model"),
        }
    }
}

impl GeminiRole {
    pub fn from_string(s: &str) -> Result<Self, String> {
        serde_json::from_value(serde_json::Value::String(s.to_string()))
            .map_err(|_| "Failed to deserialize Gemini role".to_string())
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GeminiContent {
    pub role: GeminiRole,
    #[serde(default)]
    pub parts: Vec<GeminiPart>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum GeminiPart {
    Text {
        text: String,
    },
    FunctionCall {
        #[serde(rename = "functionCall")]
        function_call: GeminiFunctionCall,
    },
    FunctionResponse {
        #[serde(rename = "functionResponse")]
        function_response: GeminiFunctionResponse,
    },
    InlineData {
        #[serde(rename = "inlineData")]
        inline_data: GeminiInlineData,
    },
}
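
// With `#[serde(untagged)]`, each variant (de)serializes to the bare JSON
// shape of a Gemini content part. Roughly, these structs map to:
//
//     {"text": "..."}
//     {"functionCall": {"name": "...", "args": {...}}}
//     {"functionResponse": {"id": "...", "name": "...", "response": {...}}}
//     {"inlineData": {"mime_type": "...", "data": "<base64>"}}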

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GeminiFunctionCall {
    #[serde(default)]
    pub id: Option<String>,
    pub name: String,
    pub args: serde_json::Value,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GeminiFunctionResponse {
    pub id: String,
    pub name: String,
    pub response: serde_json::Value,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GeminiInlineData {
    pub mime_type: String,
    pub data: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct GeminiSystemInstruction {
    pub parts: Vec<GeminiPart>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct GeminiTool {
    pub function_declarations: Vec<GeminiFunctionDeclaration>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct GeminiFunctionDeclaration {
    pub name: String,
    pub description: String,
    pub parameters_json_schema: Option<serde_json::Value>,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct GeminiGenerationConfig {
    pub max_output_tokens: Option<u32>,
    pub temperature: Option<f32>,
    pub candidate_count: Option<u32>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GeminiResponse {
    pub candidates: Option<Vec<GeminiCandidate>>,
    pub usage_metadata: Option<GeminiUsageMetadata>,
    pub model_version: Option<String>,
    pub response_id: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GeminiCandidate {
    pub content: Option<GeminiContent>,
    pub finish_reason: Option<String>,
    pub index: Option<u32>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GeminiUsageMetadata {
    pub prompt_token_count: Option<u32>,
    pub cached_content_token_count: Option<u32>,
    pub candidates_token_count: Option<u32>,
    pub tool_use_prompt_token_count: Option<u32>,
    pub thoughts_token_count: Option<u32>,
    pub total_token_count: Option<u32>,
}

impl From<LLMMessage> for GeminiContent {
    fn from(message: LLMMessage) -> Self {
        let role = match message.role.as_str() {
            "assistant" | "model" => GeminiRole::Model,
            "user" | "tool" => GeminiRole::User,
            _ => GeminiRole::User,
        };

        let parts = match message.content {
            LLMMessageContent::String(text) => vec![GeminiPart::Text { text }],
            LLMMessageContent::List(items) => items
                .into_iter()
                .map(|item| match item {
                    LLMMessageTypedContent::Text { text } => GeminiPart::Text { text },

                    LLMMessageTypedContent::ToolCall { id, name, args } => {
                        GeminiPart::FunctionCall {
                            function_call: GeminiFunctionCall {
                                id: Some(id),
                                name,
                                args,
                            },
                        }
                    }

                    LLMMessageTypedContent::ToolResult { content, .. } => {
                        GeminiPart::Text { text: content }
                    }

                    LLMMessageTypedContent::Image { source } => GeminiPart::InlineData {
                        inline_data: GeminiInlineData {
                            mime_type: source.media_type,
                            data: source.data,
                        },
                    },
                })
                .collect(),
        };

        GeminiContent { role, parts }
    }
}

impl From<GeminiContent> for LLMMessage {
    fn from(content: GeminiContent) -> Self {
        let role = content.role.to_string();
        let mut message_content = Vec::new();

        for part in content.parts {
            match part {
                GeminiPart::Text { text } => {
                    message_content.push(LLMMessageTypedContent::Text { text });
                }
                GeminiPart::FunctionCall { function_call } => {
                    message_content.push(LLMMessageTypedContent::ToolCall {
                        id: function_call.id.unwrap_or_default(),
                        name: function_call.name,
                        args: function_call.args,
                    });
                }
                GeminiPart::FunctionResponse { function_response } => {
                    message_content.push(LLMMessageTypedContent::ToolResult {
                        tool_use_id: function_response.id,
                        content: function_response.response.to_string(),
                    });
                }
                _ => {}
            }
        }

        let content = if message_content.is_empty() {
            LLMMessageContent::String(String::new())
        } else if message_content.len() == 1 {
            match &message_content[0] {
                LLMMessageTypedContent::Text { text } => LLMMessageContent::String(text.clone()),
                _ => LLMMessageContent::List(message_content),
            }
        } else {
            LLMMessageContent::List(message_content)
        };

        LLMMessage { role, content }
    }
}

impl From<LLMTool> for GeminiFunctionDeclaration {
    fn from(tool: LLMTool) -> Self {
        GeminiFunctionDeclaration {
            name: tool.name,
            description: tool.description,
            parameters_json_schema: Some(tool.input_schema),
        }
    }
}

impl From<Vec<LLMTool>> for GeminiTool {
    fn from(tools: Vec<LLMTool>) -> Self {
        GeminiTool {
            function_declarations: tools.into_iter().map(|t| t.into()).collect(),
        }
    }
}

impl From<GeminiResponse> for LLMCompletionResponse {
    fn from(response: GeminiResponse) -> Self {
        let usage = response.usage_metadata.map(|u| LLMTokenUsage {
            prompt_tokens: u.prompt_token_count.unwrap_or(0),
            completion_tokens: u.candidates_token_count.unwrap_or(0),
            total_tokens: u.total_token_count.unwrap_or(0),
            prompt_tokens_details: None,
        });

        let choices = response
            .candidates
            .unwrap_or_default()
            .into_iter()
            .enumerate()
            .map(|(index, candidate)| {
                let message = candidate
                    .content
                    .map(|c| c.into())
                    .unwrap_or_else(|| LLMMessage {
                        role: "model".to_string(),
                        content: LLMMessageContent::String(String::new()),
                    });

                let has_tool_calls = match &message.content {
                    LLMMessageContent::List(items) => items
                        .iter()
                        .any(|item| matches!(item, LLMMessageTypedContent::ToolCall { .. })),
                    _ => false,
                };

                let finish_reason = if has_tool_calls {
                    Some("tool_calls".to_string())
                } else {
                    candidate.finish_reason.map(|s| s.to_lowercase())
                };

                LLMChoice {
                    finish_reason,
                    index: index as u32,
                    message,
                }
            })
            .collect();

        LLMCompletionResponse {
            model: response.model_version.unwrap_or_else(|| "gemini".to_string()),
            object: "chat.completion".to_string(),
            choices,
            created: chrono::Utc::now().timestamp_millis() as u64,
            usage,
            id: response.response_id.unwrap_or_else(|| "unknown".to_string()),
        }
    }
}

pub struct Gemini {}

impl Gemini {
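    /// Sends a single non-streaming `generateContent` request and converts the
    /// result into an `LLMCompletionResponse`.
    ///
    /// A minimal usage sketch, assuming a valid API key (error handling elided):
    ///
    /// ```ignore
    /// let config = GeminiConfig {
    ///     api_endpoint: None, // falls back to DEFAULT_BASE_URL
    ///     api_key: Some("<key>".to_string()),
    /// };
    /// let input = GeminiInput {
    ///     model: GeminiModel::Gemini25Flash,
    ///     messages: vec![LLMMessage {
    ///         role: "user".to_string(),
    ///         content: LLMMessageContent::String("Hello!".to_string()),
    ///     }],
    ///     max_tokens: 1024,
    ///     tools: None,
    /// };
    /// let response = Gemini::chat(&config, input).await?;
    /// ```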
    pub async fn chat(
        config: &GeminiConfig,
        input: GeminiInput,
    ) -> Result<LLMCompletionResponse, AgentError> {
        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
        let client = ClientBuilder::new(reqwest::Client::new())
            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
            .build();

        let (contents, system_instruction) = convert_messages_to_gemini(input.messages)?;

        let tools = input.tools.map(|t| vec![t.into()]);

        let payload = GeminiRequest {
            contents,
            tools,
            system_instruction,
            generation_config: Some(GeminiGenerationConfig {
                max_output_tokens: Some(input.max_tokens),
                temperature: Some(0.0),
                candidate_count: Some(1),
            }),
        };

        let api_endpoint = config.api_endpoint.as_ref().map_or(DEFAULT_BASE_URL, |v| v);
        let api_key = config.api_key.as_ref().map_or("", |v| v);

        let url = format!(
            "{}/models/{}:generateContent?key={}",
            api_endpoint, input.model, api_key
        );

        let response = client
            .post(&url)
            .header("Content-Type", "application/json")
            .json(&payload)
            .send()
            .await
            .map_err(|e| AgentError::BadRequest(BadRequestErrorMessage::ApiError(e.to_string())))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();

            let error_message = if let Ok(json) = serde_json::from_str::<Value>(&error_text) {
                if let Some(error_obj) = json.get("error") {
                    if let Some(message) = error_obj.get("message").and_then(|m| m.as_str()) {
                        message.to_string()
                    } else if let Some(code) = error_obj.get("code") {
                        // `code` is numeric in Gemini error payloads, so format it
                        // via `Display` rather than expecting a JSON string.
                        format!("API error: {}", code)
                    } else {
                        error_text
                    }
                } else {
                    error_text
                }
            } else {
                error_text
            };

            return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                format!("{}: {}", status, error_message),
            )));
        }

        let response_text = response.text().await.map_err(|e| {
            AgentError::BadRequest(BadRequestErrorMessage::ApiError(format!(
                "Failed to read response body: {}",
                e
            )))
        })?;

        if let Ok(json) = serde_json::from_str::<Value>(&response_text)
            && let Some(error_obj) = json.get("error")
        {
            let error_message =
                if let Some(message) = error_obj.get("message").and_then(|m| m.as_str()) {
                    message.to_string()
                } else if let Some(code) = error_obj.get("code") {
                    format!("API error: {}", code)
                } else {
                    serde_json::to_string(error_obj)
                        .unwrap_or_else(|_| "Unknown API error".to_string())
                };
            return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                error_message,
            )));
        }

        let gemini_response: GeminiResponse = serde_json::from_str(&response_text).map_err(|e| {
            AgentError::BadRequest(BadRequestErrorMessage::ApiError(format!(
                "Failed to deserialize Gemini response: {}. Response body: {}",
                e, response_text
            )))
        })?;

        Ok(gemini_response.into())
    }

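    /// Sends a `streamGenerateContent` request, forwarding each delta over
    /// `stream_channel_tx` as it is parsed and returning the fully assembled
    /// response once the stream ends.
    ///
    /// A minimal consumer sketch (channel capacity is arbitrary):
    ///
    /// ```ignore
    /// let (tx, mut rx) = tokio::sync::mpsc::channel::<GenerationDelta>(32);
    /// tokio::spawn(async move {
    ///     while let Some(delta) = rx.recv().await {
    ///         // render content / tool-use / usage deltas as they arrive
    ///     }
    /// });
    /// let response = Gemini::chat_stream(&config, tx, input).await?;
    /// ```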
    pub async fn chat_stream(
        config: &GeminiConfig,
        stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
        input: GeminiInput,
    ) -> Result<LLMCompletionResponse, AgentError> {
        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
        let client = ClientBuilder::new(reqwest::Client::new())
            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
            .build();

        let (contents, system_instruction) = convert_messages_to_gemini(input.messages)?;

        let tools = input.tools.map(|t| vec![t.into()]);

        let payload = GeminiRequest {
            contents,
            tools,
            system_instruction,
            generation_config: Some(GeminiGenerationConfig {
                max_output_tokens: Some(input.max_tokens),
                temperature: Some(0.0),
                candidate_count: Some(1),
            }),
        };

        let api_endpoint = config.api_endpoint.as_ref().map_or(DEFAULT_BASE_URL, |v| v);
        let api_key = config.api_key.as_ref().map_or("", |v| v);

        let url = format!(
            "{}/models/{}:streamGenerateContent?key={}",
            api_endpoint, input.model, api_key
        );

        let response = client
            .post(&url)
            .header("Content-Type", "application/json")
            .json(&payload)
            .send()
            .await
            .map_err(|e| AgentError::BadRequest(BadRequestErrorMessage::ApiError(e.to_string())))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();

            let error_message = if let Ok(json) = serde_json::from_str::<Value>(&error_text) {
                if let Some(error_obj) = json.get("error") {
                    if let Some(message) = error_obj.get("message").and_then(|m| m.as_str()) {
                        message.to_string()
                    } else if let Some(code) = error_obj.get("code") {
                        format!("API error: {}", code)
                    } else {
                        error_text
                    }
                } else {
                    error_text
                }
            } else {
                error_text
            };

            return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                format!("{}: {}", status, error_message),
            )));
        }

        process_gemini_stream(response, input.model.to_string(), stream_channel_tx).await
    }
}

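/// Converts generic `LLMMessage`s into Gemini `contents`, hoisting `"system"`
/// messages into a separate `systemInstruction` and remembering tool-call ids
/// so later `ToolResult`s can be emitted as `functionResponse` parts carrying
/// the originating tool's name (falling back to `"unknown"` if the call was
/// never seen in this message list).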
fn convert_messages_to_gemini(
    messages: Vec<LLMMessage>,
) -> Result<(Vec<GeminiContent>, Option<GeminiSystemInstruction>), AgentError> {
    let mut contents = Vec::new();
    let mut system_parts = Vec::new();
    let mut tool_id_to_name = std::collections::HashMap::new();

    for message in messages {
        match message.role.as_str() {
            "system" => {
                if let LLMMessageContent::String(text) = message.content {
                    system_parts.push(GeminiPart::Text { text });
                }
            }
            _ => {
                let role = match message.role.as_str() {
                    "assistant" | "model" => GeminiRole::Model,
                    "user" | "tool" => GeminiRole::User,
                    _ => GeminiRole::User,
                };

                let mut parts = Vec::new();

                match message.content {
                    LLMMessageContent::String(text) => {
                        parts.push(GeminiPart::Text { text });
                    }
                    LLMMessageContent::List(items) => {
                        for item in items {
                            match item {
                                LLMMessageTypedContent::Text { text } => {
                                    parts.push(GeminiPart::Text { text });
                                }
                                LLMMessageTypedContent::ToolCall { id, name, args } => {
                                    tool_id_to_name.insert(id.clone(), name.clone());
                                    parts.push(GeminiPart::FunctionCall {
                                        function_call: GeminiFunctionCall {
                                            id: Some(id),
                                            name,
                                            args,
                                        },
                                    });
                                }
                                LLMMessageTypedContent::ToolResult {
                                    tool_use_id,
                                    content,
                                } => {
                                    let name = tool_id_to_name
                                        .get(&tool_use_id)
                                        .cloned()
                                        .unwrap_or_else(|| "unknown".to_string());

                                    let response_json = serde_json::json!({ "result": content });

                                    parts.push(GeminiPart::FunctionResponse {
                                        function_response: GeminiFunctionResponse {
                                            id: tool_use_id,
                                            name,
                                            response: response_json,
                                        },
                                    });
                                }
                                LLMMessageTypedContent::Image { source } => {
                                    parts.push(GeminiPart::InlineData {
                                        inline_data: GeminiInlineData {
                                            mime_type: source.media_type,
                                            data: source.data,
                                        },
                                    });
                                }
                            }
                        }
                    }
                }

                contents.push(GeminiContent { role, parts });
            }
        }
    }

    let system_instruction = if system_parts.is_empty() {
        None
    } else {
        Some(GeminiSystemInstruction {
            parts: system_parts,
        })
    };

    Ok((contents, system_instruction))
}

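// `streamGenerateContent` (without `alt=sse`) returns a single JSON array of
// response objects, delivered incrementally, roughly:
//
//     [{"candidates": [...]},
//      {"candidates": [...], "usageMetadata": {...}}]
//
// Rather than waiting for the closing `]`, the parser below tracks brace depth
// line by line and decodes each top-level object as soon as it completes.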
async fn process_gemini_stream(
    response: reqwest::Response,
    model: String,
    stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
) -> Result<LLMCompletionResponse, AgentError> {
    let mut completion_response = LLMCompletionResponse {
        id: "".to_string(),
        model: model.clone(),
        object: "chat.completion".to_string(),
        choices: vec![],
        created: chrono::Utc::now().timestamp_millis() as u64,
        usage: None,
    };

    let mut stream = response.bytes_stream();
    let mut line_buffer = String::new();
    let mut json_accumulator = String::new();
    let mut brace_depth = 0;
    let mut in_object = false;
    // Track JSON string state so braces inside string literals (e.g. a `{` in
    // generated text) do not corrupt the depth count.
    let mut in_string = false;
    let mut escaped = false;
    let mut finish_reason = None;
    let mut message_content: Vec<LLMMessageTypedContent> = Vec::new();

    while let Some(chunk) = stream.next().await {
        let chunk = chunk.map_err(|e| {
            AgentError::BadRequest(BadRequestErrorMessage::ApiError(format!(
                "Failed to read stream chunk: {}",
                e
            )))
        })?;

        let text = std::str::from_utf8(&chunk).map_err(|e| {
            AgentError::BadRequest(BadRequestErrorMessage::ApiError(format!(
                "Failed to parse UTF-8: {}",
                e
            )))
        })?;

        line_buffer.push_str(text);

        while let Some(line_end) = line_buffer.find('\n') {
            let line = line_buffer[..line_end].trim().to_string();
            line_buffer = line_buffer[line_end + 1..].to_string();

            if line.is_empty() || line == "[" || line == "]" {
                continue;
            }

            // Count braces, skipping anything inside a JSON string literal.
            // Valid JSON strings cannot span lines, so per-line scanning is safe.
            for ch in line.chars() {
                if escaped {
                    escaped = false;
                    continue;
                }
                match ch {
                    '\\' if in_string => escaped = true,
                    '"' => in_string = !in_string,
                    '{' if !in_string => {
                        brace_depth += 1;
                        in_object = true;
                    }
                    '}' if !in_string => brace_depth -= 1,
                    _ => {}
                }
            }

            if in_object {
                if !json_accumulator.is_empty() {
                    json_accumulator.push('\n');
                }
                json_accumulator.push_str(&line);
            }

            if in_object && brace_depth == 0 {
                let mut json_str = json_accumulator.trim();
                if json_str.starts_with('[') {
                    json_str = json_str[1..].trim();
                }
                if json_str.ends_with(']') {
                    json_str = json_str[..json_str.len() - 1].trim();
                }
                let json_str = json_str.trim_matches(',').trim();

                match serde_json::from_str::<GeminiResponse>(json_str) {
                    Ok(gemini_response) => {
                        if let Some(candidates) = gemini_response.candidates {
                            for candidate in candidates {
                                if let Some(reason) = candidate.finish_reason {
                                    finish_reason = Some(reason);
                                }
                                if let Some(content) = candidate.content {
                                    for part in content.parts {
                                        match part {
                                            GeminiPart::Text { text } => {
                                                stream_channel_tx
                                                    .send(GenerationDelta::Content {
                                                        content: text.clone(),
                                                    })
                                                    .await
                                                    .map_err(|e| {
                                                        AgentError::BadRequest(
                                                            BadRequestErrorMessage::ApiError(
                                                                e.to_string(),
                                                            ),
                                                        )
                                                    })?;
                                                message_content
                                                    .push(LLMMessageTypedContent::Text { text });
                                            }
                                            GeminiPart::FunctionCall { function_call } => {
                                                let GeminiFunctionCall { id, name, args } =
                                                    function_call;

                                                let id = id
                                                    .unwrap_or_else(|| Uuid::new_v4().to_string());
                                                stream_channel_tx
                                                    .send(GenerationDelta::ToolUse {
                                                        tool_use: GenerationDeltaToolUse {
                                                            id: Some(id.clone()),
                                                            name: Some(name.clone()),
                                                            input: Some(args.to_string()),
                                                            index: 0,
                                                        },
                                                    })
                                                    .await
                                                    .map_err(|e| {
                                                        AgentError::BadRequest(
                                                            BadRequestErrorMessage::ApiError(
                                                                e.to_string(),
                                                            ),
                                                        )
                                                    })?;
                                                message_content.push(
                                                    LLMMessageTypedContent::ToolCall {
                                                        id,
                                                        name,
                                                        args,
                                                    },
                                                );
                                            }
                                            _ => {}
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(usage) = gemini_response.usage_metadata {
                            let token_usage = LLMTokenUsage {
                                prompt_tokens: usage.prompt_token_count.unwrap_or(0),
                                completion_tokens: usage.candidates_token_count.unwrap_or(0),
                                total_tokens: usage.total_token_count.unwrap_or(0),
                                prompt_tokens_details: None,
                            };
                            stream_channel_tx
                                .send(GenerationDelta::Usage {
                                    usage: token_usage.clone(),
                                })
                                .await
                                .map_err(|e| {
                                    AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                                        e.to_string(),
                                    ))
                                })?;
                            completion_response.usage = Some(token_usage);
                        }

                        if let Some(response_id) = gemini_response.response_id {
                            completion_response.id = response_id;
                        }
                    }
                    Err(e) => {
                        eprintln!("Failed to parse JSON object: {}. Error: {}", json_str, e);
                    }
                }

                json_accumulator.clear();
                in_object = false;
            }
        }
    }

    let has_tool_calls = message_content
        .iter()
        .any(|c| matches!(c, LLMMessageTypedContent::ToolCall { .. }));

    let final_finish_reason = if has_tool_calls {
        Some("tool_calls".to_string())
    } else {
        finish_reason.map(|s| s.to_lowercase())
    };

    completion_response.choices = vec![LLMChoice {
        finish_reason: final_finish_reason,
        index: 0,
        message: LLMMessage {
            role: "assistant".to_string(),
            content: if message_content.is_empty() {
                LLMMessageContent::String(String::new())
            } else if message_content.len() == 1 {
                match &message_content[0] {
                    LLMMessageTypedContent::Text { text } => {
                        LLMMessageContent::String(text.clone())
                    }
                    _ => LLMMessageContent::List(message_content),
                }
            } else {
                LLMMessageContent::List(message_content)
            },
        },
    }];

    Ok(completion_response)
}
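
// A small sanity-check suite sketching the pure conversions above. The exact
// shape of `LLMMessage`/`LLMMessageContent` comes from `crate::models::llm`,
// so these are illustrative rather than exhaustive.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn model_id_round_trips_through_from_string() {
        let model = GeminiModel::from_string("gemini-2.5-pro").expect("known model id");
        assert_eq!(model, GeminiModel::Gemini25Pro);
        assert_eq!(model.to_string(), "gemini-2.5-pro");
    }

    #[test]
    fn user_string_message_becomes_single_text_part() {
        let message = LLMMessage {
            role: "user".to_string(),
            content: LLMMessageContent::String("hello".to_string()),
        };
        let content: GeminiContent = message.into();
        assert!(matches!(content.role, GeminiRole::User));
        assert!(matches!(&content.parts[..], [GeminiPart::Text { text }] if text == "hello"));
    }
}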