stakpak_shared/models/integrations/
gemini.rs

1use crate::models::error::{AgentError, BadRequestErrorMessage};
2use crate::models::llm::{
3    GenerationDelta, GenerationDeltaToolUse, LLMChoice, LLMCompletionResponse, LLMMessage,
4    LLMMessageContent, LLMMessageTypedContent, LLMTokenUsage, LLMTool,
5};
6use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
7use futures_util::StreamExt;
8use reqwest_middleware::ClientBuilder;
9use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use uuid::Uuid;
13
14const DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";
15
16#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
17pub struct GeminiConfig {
18    pub api_endpoint: Option<String>,
19    pub api_key: Option<String>,
20}
21
22#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
23pub enum GeminiModel {
24    #[default]
25    #[serde(rename = "gemini-3-pro-preview")]
26    Gemini3Pro,
27    #[serde(rename = "gemini-3-flash-preview")]
28    Gemini3Flash,
29    #[serde(rename = "gemini-2.5-pro")]
30    Gemini25Pro,
31    #[serde(rename = "gemini-2.5-flash")]
32    Gemini25Flash,
33    #[serde(rename = "gemini-2.5-flash-lite")]
34    Gemini25FlashLite,
35}
36
37impl std::fmt::Display for GeminiModel {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        match self {
40            GeminiModel::Gemini3Pro => write!(f, "gemini-3-pro-preview"),
41            GeminiModel::Gemini3Flash => write!(f, "gemini-3-flash-preview"),
42            GeminiModel::Gemini25Pro => write!(f, "gemini-2.5-pro"),
43            GeminiModel::Gemini25Flash => write!(f, "gemini-2.5-flash"),
44            GeminiModel::Gemini25FlashLite => write!(f, "gemini-2.5-flash-lite"),
45        }
46    }
47}
48
49impl GeminiModel {
50    pub fn from_string(s: &str) -> Result<Self, String> {
51        serde_json::from_value(serde_json::Value::String(s.to_string()))
52            .map_err(|_| "Failed to deserialize Gemini model".to_string())
53    }
54
55    /// Default smart model for Gemini
56    pub const DEFAULT_SMART_MODEL: GeminiModel = GeminiModel::Gemini3Pro;
57
58    /// Default eco model for Gemini
59    pub const DEFAULT_ECO_MODEL: GeminiModel = GeminiModel::Gemini3Flash;
60
61    /// Default recovery model for Gemini
62    pub const DEFAULT_RECOVERY_MODEL: GeminiModel = GeminiModel::Gemini3Flash;
63
64    /// Get default smart model as string
65    pub fn default_smart_model() -> String {
66        Self::DEFAULT_SMART_MODEL.to_string()
67    }
68
69    /// Get default eco model as string
70    pub fn default_eco_model() -> String {
71        Self::DEFAULT_ECO_MODEL.to_string()
72    }
73
74    /// Get default recovery model as string
75    pub fn default_recovery_model() -> String {
76        Self::DEFAULT_RECOVERY_MODEL.to_string()
77    }
78}
79
80impl ContextAware for GeminiModel {
81    fn context_info(&self) -> ModelContextInfo {
82        match self {
83            GeminiModel::Gemini3Pro => ModelContextInfo {
84                max_tokens: 1_000_000,
85                pricing_tiers: vec![
86                    ContextPricingTier {
87                        label: "<200k tokens".to_string(),
88                        input_cost_per_million: 2.0,
89                        output_cost_per_million: 12.0,
90                        upper_bound: Some(200_000),
91                    },
92                    ContextPricingTier {
93                        label: ">200k tokens".to_string(),
94                        input_cost_per_million: 4.0,
95                        output_cost_per_million: 18.0,
96                        upper_bound: None,
97                    },
98                ],
99                approach_warning_threshold: 0.8,
100            },
101            GeminiModel::Gemini25Pro => ModelContextInfo {
102                max_tokens: 1_000_000,
103                pricing_tiers: vec![
104                    ContextPricingTier {
105                        label: "<200k tokens".to_string(),
106                        input_cost_per_million: 1.25,
107                        output_cost_per_million: 10.0,
108                        upper_bound: Some(200_000),
109                    },
110                    ContextPricingTier {
111                        label: ">200k tokens".to_string(),
112                        input_cost_per_million: 2.50,
113                        output_cost_per_million: 15.0,
114                        upper_bound: None,
115                    },
116                ],
117                approach_warning_threshold: 0.8,
118            },
119            GeminiModel::Gemini25Flash => ModelContextInfo {
120                max_tokens: 1_000_000,
121                pricing_tiers: vec![ContextPricingTier {
122                    label: "Standard".to_string(),
123                    input_cost_per_million: 0.30,
124                    output_cost_per_million: 2.50,
125                    upper_bound: None,
126                }],
127                approach_warning_threshold: 0.8,
128            },
129            GeminiModel::Gemini3Flash => ModelContextInfo {
130                max_tokens: 1_000_000,
131                pricing_tiers: vec![ContextPricingTier {
132                    label: "Standard".to_string(),
133                    input_cost_per_million: 0.50,
134                    output_cost_per_million: 3.0,
135                    upper_bound: None,
136                }],
137                approach_warning_threshold: 0.8,
138            },
139            GeminiModel::Gemini25FlashLite => ModelContextInfo {
140                max_tokens: 1_000_000,
141                pricing_tiers: vec![ContextPricingTier {
142                    label: "Standard".to_string(),
143                    input_cost_per_million: 0.1,
144                    output_cost_per_million: 0.4,
145                    upper_bound: None,
146                }],
147                approach_warning_threshold: 0.8,
148            },
149        }
150    }
151
152    fn model_name(&self) -> String {
153        match self {
154            GeminiModel::Gemini3Pro => "Gemini 3 Pro".to_string(),
155            GeminiModel::Gemini3Flash => "Gemini 3 Flash".to_string(),
156            GeminiModel::Gemini25Pro => "Gemini 2.5 Pro".to_string(),
157            GeminiModel::Gemini25Flash => "Gemini 2.5 Flash".to_string(),
158            GeminiModel::Gemini25FlashLite => "Gemini 2.5 Flash Lite".to_string(),
159        }
160    }
161}
162
163#[derive(Serialize, Deserialize, Debug)]
164pub struct GeminiInput {
165    pub model: GeminiModel,
166    pub messages: Vec<LLMMessage>,
167    pub max_tokens: u32,
168    #[serde(skip_serializing_if = "Option::is_none")]
169    pub tools: Option<Vec<LLMTool>>,
170}
171
172#[derive(Serialize, Deserialize, Debug)]
173#[serde(rename_all = "camelCase")]
174pub struct GeminiRequest {
175    pub contents: Vec<GeminiContent>,
176
177    #[serde(skip_serializing_if = "Option::is_none")]
178    pub tools: Option<Vec<GeminiTool>>,
179
180    #[serde(skip_serializing_if = "Option::is_none")]
181    pub system_instruction: Option<GeminiSystemInstruction>, // checked
182
183    #[serde(skip_serializing_if = "Option::is_none")]
184    pub generation_config: Option<GeminiGenerationConfig>, // checked
185}
186
187#[derive(Serialize, Deserialize, Debug, Clone)]
188pub enum GeminiRole {
189    #[serde(rename = "user")]
190    User,
191    #[serde(rename = "model")]
192    Model,
193}
194
195impl std::fmt::Display for GeminiRole {
196    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197        match self {
198            GeminiRole::User => write!(f, "user"),
199            GeminiRole::Model => write!(f, "model"),
200        }
201    }
202}
203
204impl GeminiRole {
205    pub fn from_string(s: &str) -> Result<Self, String> {
206        serde_json::from_value(serde_json::Value::String(s.to_string()))
207            .map_err(|_| "Failed to deserialize Gemini role".to_string())
208    }
209}
210
211#[derive(Serialize, Deserialize, Debug, Clone)]
212pub struct GeminiContent {
213    pub role: GeminiRole,
214    #[serde(default)]
215    pub parts: Vec<GeminiPart>,
216}
217
218#[derive(Serialize, Deserialize, Debug, Clone)]
219#[serde(untagged)]
220pub enum GeminiPart {
221    Text {
222        text: String,
223    },
224    FunctionCall {
225        #[serde(rename = "functionCall")]
226        function_call: GeminiFunctionCall,
227    },
228    FunctionResponse {
229        #[serde(rename = "functionResponse")]
230        function_response: GeminiFunctionResponse,
231    },
232    InlineData {
233        #[serde(rename = "inlineData")]
234        inline_data: GeminiInlineData,
235    },
236}
237
238#[derive(Serialize, Deserialize, Debug, Clone)]
239pub struct GeminiFunctionCall {
240    #[serde(default)]
241    pub id: Option<String>,
242    pub name: String,
243    pub args: serde_json::Value,
244}
245
246#[derive(Serialize, Deserialize, Debug, Clone)]
247pub struct GeminiFunctionResponse {
248    pub id: String,
249    pub name: String,
250    pub response: serde_json::Value,
251}
252
253#[derive(Serialize, Deserialize, Debug, Clone)]
254pub struct GeminiInlineData {
255    pub mime_type: String,
256    pub data: String,
257}
258
259#[derive(Serialize, Deserialize, Debug)]
260pub struct GeminiSystemInstruction {
261    pub parts: Vec<GeminiPart>, // checked
262}
263
264#[derive(Serialize, Deserialize, Debug)]
265pub struct GeminiTool {
266    pub function_declarations: Vec<GeminiFunctionDeclaration>,
267}
268
269#[derive(Serialize, Deserialize, Debug)]
270pub struct GeminiFunctionDeclaration {
271    pub name: String,
272    pub description: String,
273    pub parameters_json_schema: Option<serde_json::Value>,
274}
275
276#[derive(Serialize, Deserialize, Debug)]
277pub struct GeminiGenerationConfig {
278    pub max_output_tokens: Option<u32>,
279    pub temperature: Option<f32>,
280    pub candidate_count: Option<u32>,
281}
282
283// Gemini API Response Structs
284
285#[derive(Serialize, Deserialize, Debug, Clone)]
286#[serde(rename_all = "camelCase")]
287pub struct GeminiResponse {
288    pub candidates: Option<Vec<GeminiCandidate>>,
289    pub usage_metadata: Option<GeminiUsageMetadata>,
290    pub model_version: Option<String>,
291    pub response_id: Option<String>,
292}
293
294#[derive(Serialize, Deserialize, Debug, Clone)]
295#[serde(rename_all = "camelCase")]
296pub struct GeminiCandidate {
297    pub content: Option<GeminiContent>,
298    pub finish_reason: Option<String>,
299    pub index: Option<u32>,
300}
301
302#[derive(Serialize, Deserialize, Debug, Clone)]
303#[serde(rename_all = "camelCase")]
304pub struct GeminiUsageMetadata {
305    pub prompt_token_count: Option<u32>,
306    pub cached_content_token_count: Option<u32>,
307    pub candidates_token_count: Option<u32>,
308    pub tool_use_prompt_token_count: Option<u32>,
309    pub thoughts_token_count: Option<u32>,
310    pub total_token_count: Option<u32>,
311}
312
313impl From<LLMMessage> for GeminiContent {
314    fn from(message: LLMMessage) -> Self {
315        let role = match message.role.as_str() {
316            "assistant" | "model" => GeminiRole::Model,
317            "user" | "tool" => GeminiRole::User,
318            _ => GeminiRole::User,
319        };
320
321        let parts = match message.content {
322            LLMMessageContent::String(text) => vec![GeminiPart::Text { text }],
323            LLMMessageContent::List(items) => items
324                .into_iter()
325                .map(|item| match item {
326                    LLMMessageTypedContent::Text { text } => GeminiPart::Text { text },
327
328                    LLMMessageTypedContent::ToolCall { id, name, args } => {
329                        GeminiPart::FunctionCall {
330                            function_call: GeminiFunctionCall {
331                                id: Some(id),
332                                name,
333                                args,
334                            },
335                        }
336                    }
337
338                    LLMMessageTypedContent::ToolResult { content, .. } => {
339                        GeminiPart::Text { text: content }
340                    }
341
342                    LLMMessageTypedContent::Image { source } => GeminiPart::InlineData {
343                        inline_data: GeminiInlineData {
344                            mime_type: source.media_type,
345                            data: source.data,
346                        },
347                    },
348                })
349                .collect(),
350        };
351
352        GeminiContent { role, parts }
353    }
354}
355
356// Conversion from GeminiContent to LLMMessage
357impl From<GeminiContent> for LLMMessage {
358    fn from(content: GeminiContent) -> Self {
359        let role = content.role.to_string();
360        let mut message_content = Vec::new();
361
362        for part in content.parts {
363            match part {
364                GeminiPart::Text { text } => {
365                    message_content.push(LLMMessageTypedContent::Text { text });
366                }
367                GeminiPart::FunctionCall { function_call } => {
368                    message_content.push(LLMMessageTypedContent::ToolCall {
369                        id: function_call.id.unwrap_or_else(|| "".to_string()),
370                        name: function_call.name,
371                        args: function_call.args,
372                    });
373                }
374                GeminiPart::FunctionResponse { function_response } => {
375                    message_content.push(LLMMessageTypedContent::ToolResult {
376                        tool_use_id: function_response.id,
377                        content: function_response.response.to_string(),
378                    });
379                }
380                //TODO: Add Image support
381                _ => {}
382            }
383        }
384
385        let content = if message_content.is_empty() {
386            LLMMessageContent::String(String::new())
387        } else if message_content.len() == 1 {
388            match &message_content[0] {
389                LLMMessageTypedContent::Text { text } => LLMMessageContent::String(text.clone()),
390                _ => LLMMessageContent::List(message_content),
391            }
392        } else {
393            LLMMessageContent::List(message_content)
394        };
395
396        LLMMessage { role, content }
397    }
398}
399
400impl From<LLMTool> for GeminiFunctionDeclaration {
401    fn from(tool: LLMTool) -> Self {
402        GeminiFunctionDeclaration {
403            name: tool.name,
404            description: tool.description,
405            parameters_json_schema: Some(tool.input_schema),
406        }
407    }
408}
409
410impl From<Vec<LLMTool>> for GeminiTool {
411    fn from(tools: Vec<LLMTool>) -> Self {
412        GeminiTool {
413            function_declarations: tools.into_iter().map(|t| t.into()).collect(),
414        }
415    }
416}
417
418impl From<GeminiResponse> for LLMCompletionResponse {
419    fn from(response: GeminiResponse) -> Self {
420        let usage = response.usage_metadata.map(|u| LLMTokenUsage {
421            prompt_tokens: u.prompt_token_count.unwrap_or(0),
422            completion_tokens: u.candidates_token_count.unwrap_or(0),
423            total_tokens: u.total_token_count.unwrap_or(0),
424            prompt_tokens_details: None,
425        });
426
427        let choices = response
428            .candidates
429            .unwrap_or_default()
430            .into_iter()
431            .enumerate()
432            .map(|(index, candidate)| {
433                let message = candidate
434                    .content
435                    .map(|c| c.into())
436                    .unwrap_or_else(|| LLMMessage {
437                        role: "model".to_string(),
438                        content: LLMMessageContent::String(String::new()),
439                    });
440
441                let has_tool_calls = match &message.content {
442                    LLMMessageContent::List(items) => items
443                        .iter()
444                        .any(|item| matches!(item, LLMMessageTypedContent::ToolCall { .. })),
445                    _ => false,
446                };
447
448                let finish_reason = if has_tool_calls {
449                    Some("tool_calls".to_string())
450                } else {
451                    candidate.finish_reason.map(|s| s.to_lowercase())
452                };
453
454                LLMChoice {
455                    finish_reason,
456                    index: index as u32,
457                    message,
458                }
459            })
460            .collect();
461
462        LLMCompletionResponse {
463            // Use model_version from the response, with fallback
464            model: response
465                .model_version
466                .unwrap_or_else(|| "gemini".to_string()),
467            object: "chat.completion".to_string(),
468            choices,
469            created: chrono::Utc::now().timestamp_millis() as u64,
470            usage,
471            id: response
472                .response_id
473                .unwrap_or_else(|| "unknown".to_string()),
474        }
475    }
476}
477
478pub struct Gemini {}
479
480impl Gemini {
481    pub async fn chat(
482        config: &GeminiConfig,
483        input: GeminiInput,
484    ) -> Result<LLMCompletionResponse, AgentError> {
485        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
486        let client = ClientBuilder::new(reqwest::Client::new())
487            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
488            .build();
489
490        let (contents, system_instruction) = convert_messages_to_gemini(input.messages)?;
491
492        let tools = input.tools.map(|t| vec![t.into()]);
493
494        let payload = GeminiRequest {
495            contents,
496            tools,
497            system_instruction,
498            generation_config: Some(GeminiGenerationConfig {
499                max_output_tokens: Some(input.max_tokens),
500                temperature: Some(0.0),
501                candidate_count: Some(1),
502            }),
503        };
504
505        let api_endpoint = config.api_endpoint.as_ref().map_or(DEFAULT_BASE_URL, |v| v);
506        let api_key = config.api_key.as_ref().map_or("", |v| v);
507
508        let url = format!(
509            "{}/models/{}:generateContent?key={}",
510            api_endpoint, input.model, api_key
511        );
512
513        let response = client
514            .post(&url)
515            .header("Content-Type", "application/json")
516            .json(&payload)
517            .send()
518            .await
519            .map_err(|e| AgentError::BadRequest(BadRequestErrorMessage::ApiError(e.to_string())))?;
520
521        if !response.status().is_success() {
522            let status = response.status();
523            let error_text = response.text().await.unwrap_or_default();
524
525            // Try to parse as JSON and extract error message
526            let error_message = if let Ok(json) = serde_json::from_str::<Value>(&error_text) {
527                if let Some(error_obj) = json.get("error") {
528                    if let Some(message) = error_obj.get("message").and_then(|m| m.as_str()) {
529                        message.to_string()
530                    } else if let Some(code) = error_obj.get("code").and_then(|c| c.as_str()) {
531                        format!("API error: {}", code)
532                    } else {
533                        error_text
534                    }
535                } else {
536                    error_text
537                }
538            } else {
539                error_text
540            };
541
542            return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
543                format!("{}: {}", status, error_message),
544            )));
545        }
546
547        // Log response body before attempting to decode
548        let response_text = response.text().await.map_err(|e| {
549            AgentError::BadRequest(BadRequestErrorMessage::ApiError(format!(
550                "Failed to read response body: {}",
551                e
552            )))
553        })?;
554
555        // Check if the response contains an error field before deserializing
556        if let Ok(json) = serde_json::from_str::<Value>(&response_text)
557            && let Some(error_obj) = json.get("error")
558        {
559            let error_message = if let Some(message) =
560                error_obj.get("message").and_then(|m| m.as_str())
561            {
562                message.to_string()
563            } else if let Some(code) = error_obj.get("code").and_then(|c| c.as_str()) {
564                format!("API error: {}", code)
565            } else {
566                serde_json::to_string(error_obj).unwrap_or_else(|_| "Unknown API error".to_string())
567            };
568            return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
569                error_message,
570            )));
571        }
572
573        let gemini_response: GeminiResponse =
574            serde_json::from_str(&response_text).map_err(|e| {
575                AgentError::BadRequest(BadRequestErrorMessage::ApiError(format!(
576                    "Failed to deserialize Gemini response: {}. Response body: {}",
577                    e, response_text
578                )))
579            })?;
580
581        Ok(gemini_response.into())
582    }
583
584    pub async fn chat_stream(
585        config: &GeminiConfig,
586        stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
587        input: GeminiInput,
588    ) -> Result<LLMCompletionResponse, AgentError> {
589        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
590        let client = ClientBuilder::new(reqwest::Client::new())
591            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
592            .build();
593
594        let (contents, system_instruction) = convert_messages_to_gemini(input.messages)?;
595
596        let tools = input.tools.map(|t| vec![t.into()]);
597
598        let payload = GeminiRequest {
599            contents,
600            tools,
601            system_instruction,
602            generation_config: Some(GeminiGenerationConfig {
603                max_output_tokens: Some(input.max_tokens),
604                temperature: Some(0.0),
605                candidate_count: Some(1),
606            }),
607        };
608
609        let api_endpoint = config.api_endpoint.as_ref().map_or(DEFAULT_BASE_URL, |v| v);
610        let api_key = config.api_key.as_ref().map_or("", |v| v);
611
612        let url = format!(
613            "{}/models/{}:streamGenerateContent?key={}",
614            api_endpoint, input.model, api_key
615        );
616
617        let response = client
618            .post(&url)
619            .header("Content-Type", "application/json")
620            .json(&payload)
621            .send()
622            .await
623            .map_err(|e| AgentError::BadRequest(BadRequestErrorMessage::ApiError(e.to_string())))?;
624
625        if !response.status().is_success() {
626            let status = response.status();
627            let error_text = response.text().await.unwrap_or_default();
628
629            // Try to parse as JSON and extract error message
630            let error_message = if let Ok(json) = serde_json::from_str::<Value>(&error_text) {
631                if let Some(error_obj) = json.get("error") {
632                    if let Some(message) = error_obj.get("message").and_then(|m| m.as_str()) {
633                        message.to_string()
634                    } else if let Some(code) = error_obj.get("code").and_then(|c| c.as_str()) {
635                        format!("API error: {}", code)
636                    } else {
637                        error_text
638                    }
639                } else {
640                    error_text
641                }
642            } else {
643                error_text
644            };
645
646            return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
647                format!("{}: {}", status, error_message),
648            )));
649        }
650
651        process_gemini_stream(response, input.model.to_string(), stream_channel_tx).await
652    }
653}
654
655fn convert_messages_to_gemini(
656    messages: Vec<LLMMessage>,
657) -> Result<(Vec<GeminiContent>, Option<GeminiSystemInstruction>), AgentError> {
658    let mut contents = Vec::new();
659    let mut system_parts = Vec::new();
660    let mut tool_id_to_name = std::collections::HashMap::new();
661
662    for message in messages {
663        match message.role.as_str() {
664            "system" => {
665                if let LLMMessageContent::String(text) = message.content {
666                    system_parts.push(GeminiPart::Text { text });
667                }
668            }
669            _ => {
670                let role = match message.role.as_str() {
671                    "assistant" | "model" => GeminiRole::Model,
672                    "user" | "tool" => GeminiRole::User,
673                    _ => GeminiRole::User,
674                };
675
676                let mut parts = Vec::new();
677
678                match message.content {
679                    LLMMessageContent::String(text) => {
680                        parts.push(GeminiPart::Text { text });
681                    }
682                    LLMMessageContent::List(items) => {
683                        for item in items {
684                            match item {
685                                LLMMessageTypedContent::Text { text } => {
686                                    parts.push(GeminiPart::Text { text });
687                                }
688                                LLMMessageTypedContent::ToolCall { id, name, args } => {
689                                    tool_id_to_name.insert(id.clone(), name.clone());
690                                    parts.push(GeminiPart::FunctionCall {
691                                        function_call: GeminiFunctionCall {
692                                            id: Some(id),
693                                            name,
694                                            args,
695                                        },
696                                    });
697                                }
698                                LLMMessageTypedContent::ToolResult {
699                                    tool_use_id,
700                                    content,
701                                } => {
702                                    let name = tool_id_to_name
703                                        .get(&tool_use_id)
704                                        .cloned()
705                                        .unwrap_or_else(|| "unknown".to_string());
706
707                                    // Gemini expects a JSON object for the response
708                                    let response_json = serde_json::json!({ "result": content });
709
710                                    parts.push(GeminiPart::FunctionResponse {
711                                        function_response: GeminiFunctionResponse {
712                                            id: tool_use_id,
713                                            name,
714                                            response: response_json,
715                                        },
716                                    });
717                                }
718                                LLMMessageTypedContent::Image { source } => {
719                                    parts.push(GeminiPart::InlineData {
720                                        inline_data: GeminiInlineData {
721                                            mime_type: source.media_type,
722                                            data: source.data,
723                                        },
724                                    });
725                                }
726                            }
727                        }
728                    }
729                }
730
731                contents.push(GeminiContent { role, parts });
732            }
733        }
734    }
735
736    let system_instruction = if system_parts.is_empty() {
737        None
738    } else {
739        Some(GeminiSystemInstruction {
740            parts: system_parts,
741        })
742    };
743
744    Ok((contents, system_instruction))
745}
746
747async fn process_gemini_stream(
748    response: reqwest::Response,
749    model: String,
750    stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
751) -> Result<LLMCompletionResponse, AgentError> {
752    let mut completion_response = LLMCompletionResponse {
753        id: "".to_string(),
754        model: model.clone(),
755        object: "chat.completion".to_string(),
756        choices: vec![],
757        created: chrono::Utc::now().timestamp_millis() as u64,
758        usage: None,
759    };
760
761    let mut stream = response.bytes_stream();
762    let mut line_buffer = String::new();
763    let mut json_accumulator = String::new();
764    let mut brace_depth = 0;
765    let mut in_object = false;
766    let mut finish_reason = None;
767    let mut message_content: Vec<LLMMessageTypedContent> = Vec::new();
768
769    while let Some(chunk) = stream.next().await {
770        let chunk = chunk.map_err(|e| {
771            AgentError::BadRequest(BadRequestErrorMessage::ApiError(format!(
772                "Failed to read stream chunk: {}",
773                e
774            )))
775        })?;
776
777        let text = std::str::from_utf8(&chunk).map_err(|e| {
778            AgentError::BadRequest(BadRequestErrorMessage::ApiError(format!(
779                "Failed to parse UTF-8: {}",
780                e
781            )))
782        })?;
783
784        line_buffer.push_str(text);
785
786        // Process complete lines from buffer
787        while let Some(line_end) = line_buffer.find('\n') {
788            let line = line_buffer[..line_end].trim().to_string();
789            line_buffer = line_buffer[line_end + 1..].to_string();
790
791            // Skip empty lines and array delimiters
792            if line.is_empty() || line == "[" || line == "]" {
793                continue;
794            }
795
796            // Track braces to detect complete JSON objects
797            for ch in line.chars() {
798                match ch {
799                    '{' => {
800                        brace_depth += 1;
801                        in_object = true;
802                    }
803                    '}' => {
804                        brace_depth -= 1;
805                    }
806                    _ => {}
807                }
808            }
809
810            // Accumulate JSON lines
811            if in_object {
812                if !json_accumulator.is_empty() {
813                    json_accumulator.push('\n');
814                }
815                json_accumulator.push_str(&line);
816            }
817
818            // When we reach depth 0, we have a complete JSON object
819            if in_object && brace_depth == 0 {
820                let mut json_str = json_accumulator.trim();
821                if json_str.starts_with('[') {
822                    json_str = json_str[1..].trim();
823                }
824                if json_str.ends_with(']') {
825                    json_str = json_str[..json_str.len() - 1].trim();
826                }
827                let json_str = json_str.trim_matches(',').trim();
828
829                // Try to parse the complete JSON object
830                match serde_json::from_str::<GeminiResponse>(json_str) {
831                    Ok(gemini_response) => {
832                        // Process candidates
833                        if let Some(candidates) = gemini_response.candidates {
834                            for candidate in candidates {
835                                if let Some(reason) = candidate.finish_reason {
836                                    finish_reason = Some(reason);
837                                }
838                                if let Some(content) = candidate.content {
839                                    for part in content.parts {
840                                        match part {
841                                            GeminiPart::Text { text } => {
842                                                stream_channel_tx
843                                                    .send(GenerationDelta::Content {
844                                                        content: text.clone(),
845                                                    })
846                                                    .await
847                                                    .map_err(|e| {
848                                                        AgentError::BadRequest(
849                                                            BadRequestErrorMessage::ApiError(
850                                                                e.to_string(),
851                                                            ),
852                                                        )
853                                                    })?;
854                                                message_content
855                                                    .push(LLMMessageTypedContent::Text { text });
856                                            }
857                                            GeminiPart::FunctionCall { function_call } => {
858                                                let GeminiFunctionCall { id, name, args } =
859                                                    function_call;
860
861                                                let id = id
862                                                    .unwrap_or_else(|| Uuid::new_v4().to_string());
863                                                let name_clone = name.clone();
864                                                let args_clone = args.clone();
865                                                stream_channel_tx
866                                                    .send(GenerationDelta::ToolUse {
867                                                        tool_use: GenerationDeltaToolUse {
868                                                            id: Some(id.clone()),
869                                                            name: Some(name_clone),
870                                                            input: Some(args_clone.to_string()),
871                                                            index: 0,
872                                                        },
873                                                    })
874                                                    .await
875                                                    .map_err(|e| {
876                                                        AgentError::BadRequest(
877                                                            BadRequestErrorMessage::ApiError(
878                                                                e.to_string(),
879                                                            ),
880                                                        )
881                                                    })?;
882                                                message_content.push(
883                                                    LLMMessageTypedContent::ToolCall {
884                                                        id,
885                                                        name,
886                                                        args,
887                                                    },
888                                                );
889                                            }
890                                            _ => {}
891                                        }
892                                    }
893                                }
894                            }
895                        }
896
897                        // Update usage metadata
898                        if let Some(usage) = gemini_response.usage_metadata {
899                            let token_usage = LLMTokenUsage {
900                                prompt_tokens: usage.prompt_token_count.unwrap_or(0),
901                                completion_tokens: usage.candidates_token_count.unwrap_or(0),
902                                total_tokens: usage.total_token_count.unwrap_or(0),
903                                prompt_tokens_details: None,
904                            };
905                            stream_channel_tx
906                                .send(GenerationDelta::Usage {
907                                    usage: token_usage.clone(),
908                                })
909                                .await
910                                .map_err(|e| {
911                                    AgentError::BadRequest(BadRequestErrorMessage::ApiError(
912                                        e.to_string(),
913                                    ))
914                                })?;
915                            completion_response.usage = Some(token_usage);
916                        }
917
918                        // Update response ID if available
919                        if let Some(response_id) = gemini_response.response_id {
920                            completion_response.id = response_id;
921                        }
922                    }
923                    Err(e) => {
924                        eprintln!("Failed to parse JSON object: {}. Error: {}", json_str, e);
925                    }
926                }
927
928                // Reset for next object
929                json_accumulator.clear();
930                in_object = false;
931            }
932        }
933    }
934
935    let has_tool_calls = message_content
936        .iter()
937        .any(|c| matches!(c, LLMMessageTypedContent::ToolCall { .. }));
938
939    let final_finish_reason = if has_tool_calls {
940        Some("tool_calls".to_string())
941    } else {
942        finish_reason.map(|s| s.to_lowercase())
943    };
944
945    // Build final message content
946    completion_response.choices = vec![LLMChoice {
947        finish_reason: final_finish_reason,
948        index: 0,
949        message: LLMMessage {
950            role: "assistant".to_string(),
951            content: if message_content.is_empty() {
952                LLMMessageContent::String(String::new())
953            } else if message_content.len() == 1
954                && matches!(&message_content[0], LLMMessageTypedContent::Text { .. })
955            {
956                if let LLMMessageTypedContent::Text { text } = &message_content[0] {
957                    LLMMessageContent::String(text.clone())
958                } else {
959                    LLMMessageContent::List(message_content)
960                }
961            } else {
962                LLMMessageContent::List(message_content)
963            },
964        },
965    }];
966
967    Ok(completion_response)
968}