// sgr_agent/gemini.rs
//! Gemini API client — structured output + function calling.
//!
//! Supports both Google AI Studio (API key) and Vertex AI (ADC).
//!
//! Two modes combined:
//! - **Structured output**: `generationConfig.responseMimeType = "application/json"`
//!   + `responseSchema` — forces model to return JSON matching the SGR envelope.
//! - **Function calling**: `tools[].functionDeclarations` — model emits `functionCall`
//!   parts that map to your Rust tool structs.
//!
//! The model can return BOTH structured text AND function calls in one response.
13use crate::schema::response_schema_for;
14use crate::tool::ToolDef;
15use crate::types::*;
16use schemars::JsonSchema;
17use serde::de::DeserializeOwned;
18use serde_json::{json, Value};
19
/// Gemini API client.
///
/// Works against both Google AI Studio (API-key auth) and Vertex AI
/// (bearer-token auth); which one is used is decided per-request from the
/// `ProviderConfig` (see `build_url`).
pub struct GeminiClient {
    // Model name, credentials, and generation settings (temperature, max tokens).
    config: ProviderConfig,
    // HTTP client reused for every request made by this instance.
    http: reqwest::Client,
}
25
26impl GeminiClient {
27    pub fn new(config: ProviderConfig) -> Self {
28        Self {
29            config,
30            http: reqwest::Client::new(),
31        }
32    }
33
34    /// Quick constructor for Google AI Studio (API key auth).
35    pub fn from_api_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
36        Self::new(ProviderConfig::gemini(api_key, model))
37    }
38
39    /// SGR call: structured output (typed response) + function calling (tools).
40    ///
41    /// Returns `SgrResponse<T>` where:
42    /// - `output`: parsed structured response (if model returned text)
43    /// - `tool_calls`: function calls (if model used tools)
44    ///
45    /// The model may return either or both.
46    pub async fn call<T: JsonSchema + DeserializeOwned>(
47        &self,
48        messages: &[Message],
49        tools: &[ToolDef],
50    ) -> Result<SgrResponse<T>, SgrError> {
51        let body = self.build_request::<T>(messages, tools)?;
52        let url = self.build_url();
53
54        tracing::debug!(url = %url, model = %self.config.model, "gemini_request");
55
56        let mut req = self.http.post(&url).json(&body);
57        if self.config.project_id.is_some() && !self.config.api_key.is_empty() {
58            req = req.bearer_auth(&self.config.api_key);
59        }
60        let response = req.send().await?;
61
62        let status = response.status().as_u16();
63        let headers = response.headers().clone();
64        if status != 200 {
65            let body = response.text().await.unwrap_or_default();
66            return Err(SgrError::from_response_parts(status, body, &headers));
67        }
68
69        let response_body: Value = response.json().await?;
70        let rate_limit = RateLimitInfo::from_headers(&headers);
71        self.parse_response(&response_body, rate_limit)
72    }
73
74    /// SGR call with structured output only (no tools).
75    ///
76    /// Shorthand for `call::<T>(messages, &[])`.
77    pub async fn structured<T: JsonSchema + DeserializeOwned>(
78        &self,
79        messages: &[Message],
80    ) -> Result<T, SgrError> {
81        let resp = self.call::<T>(messages, &[]).await?;
82        resp.output.ok_or(SgrError::EmptyResponse)
83    }
84
    /// Flexible call: no structured output API, parse JSON from raw text.
    ///
    /// For use with text-only proxies (CLI proxy, Codex proxy) where
    /// the model can't enforce JSON schema. Uses AnyOf cascade + coercion.
    ///
    /// Auto-injects JSON Schema into the system prompt so the model knows
    /// the expected format (like BAML does).
    ///
    /// Returns `SgrResponse` with `output` set to a best-effort parse of the
    /// text (`None` on parse failure), native `tool_calls` if any, and the
    /// untouched `raw_text` preserved for caller-side fallbacks. Errors with
    /// `SgrError::Schema` only when the response is completely empty
    /// (no parsed output, no text, no tool calls).
    pub async fn flexible<T: JsonSchema + DeserializeOwned>(
        &self,
        messages: &[Message],
    ) -> Result<SgrResponse<T>, SgrError> {
        // Send without responseSchema — plain text response
        // Use text mode for tool messages (no functionDeclarations in this mode)
        let contents = self.messages_to_contents_text(messages);
        let mut system_instruction = self.extract_system(messages);

        // Auto-inject schema hint into system prompt
        let schema = response_schema_for::<T>();
        let schema_hint = format!(
            "\n\nRespond with valid JSON matching this schema:\n{}\n\nDo NOT wrap in markdown code blocks.",
            serde_json::to_string_pretty(&schema).unwrap_or_default()
        );
        system_instruction = Some(match system_instruction {
            Some(s) => format!("{}{}", s, schema_hint),
            None => schema_hint,
        });

        let mut gen_config = json!({
            "temperature": self.config.temperature,
        });
        if let Some(max_tokens) = self.config.max_tokens {
            gen_config["maxOutputTokens"] = json!(max_tokens);
        }

        let mut body = json!({
            "contents": contents,
            "generationConfig": gen_config,
        });
        if let Some(system) = system_instruction {
            body["systemInstruction"] = json!({
                "parts": [{"text": system}]
            });
        }

        // Bearer auth only for Vertex AI; AI Studio's key rides in the URL.
        let url = self.build_url();
        let mut req = self.http.post(&url).json(&body);
        if self.config.project_id.is_some() && !self.config.api_key.is_empty() {
            req = req.bearer_auth(&self.config.api_key);
        }
        let response = req.send().await?;
        let status = response.status().as_u16();
        let headers = response.headers().clone();
        if status != 200 {
            let body = response.text().await.unwrap_or_default();
            return Err(SgrError::from_response_parts(status, body, &headers));
        }

        let response_body: Value = response.json().await?;
        let rate_limit = RateLimitInfo::from_headers(&headers);

        // Extract raw text
        let raw_text = self.extract_raw_text(&response_body);
        if raw_text.trim().is_empty() {
            // Log finish reason and response parts for debugging
            if let Some(candidate) = response_body.get("candidates").and_then(|c| c.get(0)) {
                let reason = candidate
                    .get("finishReason")
                    .and_then(|r| r.as_str())
                    .unwrap_or("unknown");
                tracing::warn!(
                    finish_reason = reason,
                    has_parts = candidate
                        .get("content")
                        .and_then(|c| c.get("parts"))
                        .is_some(),
                    "empty raw_text from Gemini"
                );
            }
        }
        // Token accounting from usageMetadata; None if any count is missing.
        let usage = response_body.get("usageMetadata").and_then(|u| {
            Some(Usage {
                prompt_tokens: u.get("promptTokenCount")?.as_u64()? as u32,
                completion_tokens: u.get("candidatesTokenCount")?.as_u64()? as u32,
                total_tokens: u.get("totalTokenCount")?.as_u64()? as u32,
            })
        });

        // Extract native function calls (Gemini may use functionCall parts
        // even without explicit functionDeclarations — especially newer models).
        let tool_calls = self.extract_tool_calls(&response_body);

        // Flexible parse with coercion.
        // If parsing fails, return output=None with raw_text preserved
        // so callers can implement fallback logic (e.g. wrap in finish tool).
        let output = crate::flexible_parser::parse_flexible_coerced::<T>(&raw_text)
            .map(|r| r.value)
            .ok();

        if output.is_none() && raw_text.trim().is_empty() && tool_calls.is_empty() {
            // Log raw response for debugging
            let parts_summary = response_body
                .get("candidates")
                .and_then(|c| c.get(0))
                .and_then(|c| c.get("content"))
                .and_then(|c| c.get("parts"))
                .and_then(|p| p.as_array())
                .map(|parts| {
                    parts
                        .iter()
                        .map(|p| {
                            if p.get("text").is_some() {
                                "text".to_string()
                            } else if p.get("functionCall").is_some() {
                                format!(
                                    "functionCall:{}",
                                    p["functionCall"]["name"].as_str().unwrap_or("?")
                                )
                            } else {
                                format!("unknown:{}", p)
                            }
                        })
                        .collect::<Vec<_>>()
                        .join(", ")
                })
                .unwrap_or_else(|| "no parts".into());
            // Log full candidate for debugging
            let candidate_json = response_body
                .get("candidates")
                .and_then(|c| c.get(0))
                .map(|c| serde_json::to_string_pretty(c).unwrap_or_default())
                .unwrap_or_else(|| "no candidates".into());
            tracing::error!(
                parts = parts_summary,
                candidate = candidate_json.as_str(),
                "SGR empty response"
            );
            return Err(SgrError::Schema(format!(
                "Empty response from model (parts: {})",
                parts_summary
            )));
        }

        Ok(SgrResponse {
            output,
            tool_calls,
            raw_text,
            usage,
            rate_limit,
        })
    }
235
236    /// Tool-only call: no structured output schema, just function calling.
237    ///
238    /// Returns raw tool calls.
239    pub async fn tools_call(
240        &self,
241        messages: &[Message],
242        tools: &[ToolDef],
243    ) -> Result<Vec<ToolCall>, SgrError> {
244        let body = self.build_tools_only_request(messages, tools)?;
245        let url = self.build_url();
246
247        let mut req = self.http.post(&url).json(&body);
248        if self.config.project_id.is_some() && !self.config.api_key.is_empty() {
249            req = req.bearer_auth(&self.config.api_key);
250        }
251        let response = req.send().await?;
252        let status = response.status().as_u16();
253        let headers = response.headers().clone();
254        if status != 200 {
255            let body = response.text().await.unwrap_or_default();
256            return Err(SgrError::from_response_parts(status, body, &headers));
257        }
258
259        let response_body: Value = response.json().await?;
260        Ok(self.extract_tool_calls(&response_body))
261    }
262
263    // --- Private ---
264
265    fn build_url(&self) -> String {
266        if let Some(project_id) = &self.config.project_id {
267            // Vertex AI
268            let location = self.config.location.as_deref().unwrap_or("global");
269            let host = if location == "global" {
270                "aiplatform.googleapis.com".to_string()
271            } else {
272                format!("{location}-aiplatform.googleapis.com")
273            };
274            format!(
275                "https://{host}/v1/projects/{project_id}/locations/{location}/publishers/google/models/{}:generateContent",
276                self.config.model
277            )
278        } else {
279            // Google AI Studio
280            format!(
281                "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
282                self.config.model, self.config.api_key
283            )
284        }
285    }
286
287    fn build_request<T: JsonSchema>(
288        &self,
289        messages: &[Message],
290        tools: &[ToolDef],
291    ) -> Result<Value, SgrError> {
292        // Use functionResponse format only when tools are present
293        let contents = if tools.is_empty() {
294            self.messages_to_contents_text(messages)
295        } else {
296            self.messages_to_contents(messages)
297        };
298        let system_instruction = self.extract_system(messages);
299
300        // When using function calling, Gemini doesn't support responseMimeType + functionDeclarations.
301        // Use structured output (JSON mode) only when there are no tools.
302        let mut gen_config = json!({
303            "temperature": self.config.temperature,
304        });
305
306        if tools.is_empty() {
307            gen_config["responseMimeType"] = json!("application/json");
308            gen_config["responseSchema"] = response_schema_for::<T>();
309        }
310
311        if let Some(max_tokens) = self.config.max_tokens {
312            gen_config["maxOutputTokens"] = json!(max_tokens);
313        }
314
315        let mut body = json!({
316            "contents": contents,
317            "generationConfig": gen_config,
318        });
319
320        if let Some(system) = system_instruction {
321            body["systemInstruction"] = json!({
322                "parts": [{"text": system}]
323            });
324        }
325
326        if !tools.is_empty() {
327            let function_declarations: Vec<Value> = tools.iter().map(|t| t.to_gemini()).collect();
328            body["tools"] = json!([{
329                "functionDeclarations": function_declarations,
330            }]);
331            body["toolConfig"] = json!({
332                "functionCallingConfig": {
333                    "mode": "AUTO"
334                }
335            });
336        }
337
338        Ok(body)
339    }
340
341    fn build_tools_only_request(
342        &self,
343        messages: &[Message],
344        tools: &[ToolDef],
345    ) -> Result<Value, SgrError> {
346        let contents = self.messages_to_contents(messages);
347        let system_instruction = self.extract_system(messages);
348
349        let mut gen_config = json!({
350            "temperature": self.config.temperature,
351        });
352        if let Some(max_tokens) = self.config.max_tokens {
353            gen_config["maxOutputTokens"] = json!(max_tokens);
354        }
355
356        let function_declarations: Vec<Value> = tools.iter().map(|t| t.to_gemini()).collect();
357
358        let mut body = json!({
359            "contents": contents,
360            "generationConfig": gen_config,
361            "tools": [{
362                "functionDeclarations": function_declarations,
363            }],
364            "toolConfig": {
365                "functionCallingConfig": {
366                    "mode": "ANY"
367                }
368            }
369        });
370
371        if let Some(system) = system_instruction {
372            body["systemInstruction"] = json!({
373                "parts": [{"text": system}]
374            });
375        }
376
377        Ok(body)
378    }
379
    /// Convert messages to Gemini contents format.
    ///
    /// When `use_function_response` is true, Tool messages become `functionResponse` parts
    /// (for native function calling mode). When false, they become user text messages
    /// (for structured output / flexible mode where no function declarations are sent).
    fn messages_to_contents(&self, messages: &[Message]) -> Vec<Value> {
        // Native function-calling mode.
        self.messages_to_contents_inner(messages, true)
    }

    // Text mode: tool results flattened into user text turns.
    fn messages_to_contents_text(&self, messages: &[Message]) -> Vec<Value> {
        self.messages_to_contents_inner(messages, false)
    }
392
    /// Shared implementation of message → Gemini `contents` conversion.
    ///
    /// `use_function_response` selects how tool traffic is encoded:
    /// - `true`: assistant tool calls become `functionCall` parts and tool
    ///   results become `functionResponse` parts.
    /// - `false`: tool results are flattened into user text turns.
    fn messages_to_contents_inner(
        &self,
        messages: &[Message],
        use_function_response: bool,
    ) -> Vec<Value> {
        let mut contents = Vec::new();

        // Manual index loop (not `for`) because the Tool branch can consume a
        // whole run of consecutive messages in a single iteration.
        let mut i = 0;
        while i < messages.len() {
            let msg = &messages[i];
            match msg.role {
                Role::System => {
                    // System prompts are sent separately via `systemInstruction`
                    // (see `extract_system`), so they are skipped here.
                    i += 1;
                }
                Role::User => {
                    contents.push(json!({
                        "role": "user",
                        "parts": [{"text": msg.content}]
                    }));
                    i += 1;
                }
                Role::Assistant => {
                    if use_function_response && !msg.tool_calls.is_empty() {
                        // Model turn with function calls — include functionCall parts
                        let mut parts: Vec<Value> = Vec::new();
                        if !msg.content.is_empty() {
                            parts.push(json!({"text": msg.content}));
                        }
                        for tc in &msg.tool_calls {
                            parts.push(json!({
                                "functionCall": {
                                    "name": tc.name,
                                    "args": tc.arguments
                                }
                            }));
                        }
                        contents.push(json!({
                            "role": "model",
                            "parts": parts
                        }));
                    } else {
                        contents.push(json!({
                            "role": "model",
                            "parts": [{"text": msg.content}]
                        }));
                    }
                    i += 1;
                }
                Role::Tool => {
                    if use_function_response {
                        // Gemini requires ALL functionResponses for one turn in a SINGLE
                        // "function" content entry. Collect consecutive Tool messages.
                        let mut parts = Vec::new();
                        while i < messages.len() && messages[i].role == Role::Tool {
                            let tool_msg = &messages[i];
                            let call_id = tool_msg.tool_call_id.as_deref().unwrap_or("unknown");
                            // Our ids are "call#<name>#<counter>" (see parse_response);
                            // recover the function name, falling back to the raw id
                            // for legacy/foreign formats.
                            let func_name = match call_id.split('#').collect::<Vec<_>>().as_slice()
                            {
                                ["call", name, _counter] => *name,
                                _ => call_id,
                            };
                            parts.push(json!({
                                "functionResponse": {
                                    "name": func_name,
                                    "response": {
                                        "content": tool_msg.content,
                                    }
                                }
                            }));
                            i += 1;
                        }
                        contents.push(json!({
                            "role": "function",
                            "parts": parts
                        }));
                    } else {
                        // Text mode — convert tool results to user messages
                        let call_id = msg.tool_call_id.as_deref().unwrap_or("tool");
                        contents.push(json!({
                            "role": "user",
                            "parts": [{"text": format!("[{}] {}", call_id, msg.content)}]
                        }));
                        i += 1;
                    }
                }
            }
        }

        contents
    }
483
484    fn extract_system(&self, messages: &[Message]) -> Option<String> {
485        let system_parts: Vec<&str> = messages
486            .iter()
487            .filter(|m| m.role == Role::System)
488            .map(|m| m.content.as_str())
489            .collect();
490
491        if system_parts.is_empty() {
492            None
493        } else {
494            Some(system_parts.join("\n\n"))
495        }
496    }
497
    /// Parse a full Gemini response body into an `SgrResponse<T>`.
    ///
    /// Walks every candidate's parts, accumulating:
    /// - text parts into `raw_text` (the first JSON-parseable one becomes `output`)
    /// - `functionCall` parts into `tool_calls` with unique `call#<name>#<n>` ids
    ///
    /// Errors with `EmptyResponse` when there are no candidates, or neither a
    /// parsed output nor any tool call was found.
    fn parse_response<T: DeserializeOwned>(
        &self,
        body: &Value,
        rate_limit: Option<RateLimitInfo>,
    ) -> Result<SgrResponse<T>, SgrError> {
        let mut output: Option<T> = None;
        let mut tool_calls = Vec::new();
        let mut raw_text = String::new();
        // Counter spans ALL candidates so tool-call ids are globally unique.
        let mut call_counter: u32 = 0;

        // Parse usage
        let usage = body.get("usageMetadata").and_then(|u| {
            Some(Usage {
                prompt_tokens: u.get("promptTokenCount")?.as_u64()? as u32,
                completion_tokens: u.get("candidatesTokenCount")?.as_u64()? as u32,
                total_tokens: u.get("totalTokenCount")?.as_u64()? as u32,
            })
        });

        // Extract from candidates
        let candidates = body
            .get("candidates")
            .and_then(|c| c.as_array())
            .ok_or(SgrError::EmptyResponse)?;

        for candidate in candidates {
            let parts = candidate
                .get("content")
                .and_then(|c| c.get("parts"))
                .and_then(|p| p.as_array());

            if let Some(parts) = parts {
                for part in parts {
                    // Text part → structured output
                    if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
                        raw_text.push_str(text);
                        // Only the first successfully-parsed text part wins;
                        // later text still accumulates into raw_text.
                        if output.is_none() {
                            match serde_json::from_str::<T>(text) {
                                Ok(parsed) => output = Some(parsed),
                                Err(e) => {
                                    tracing::warn!(error = %e, "failed to parse structured output");
                                }
                            }
                        }
                    }

                    // Function call part → tool call
                    if let Some(fc) = part.get("functionCall") {
                        let name = fc
                            .get("name")
                            .and_then(|n| n.as_str())
                            .unwrap_or("unknown")
                            .to_string();
                        let args = fc.get("args").cloned().unwrap_or(json!({}));
                        call_counter += 1;
                        tool_calls.push(ToolCall {
                            id: format!("call#{}#{}", name, call_counter),
                            name,
                            arguments: args,
                        });
                    }
                }
            }
        }

        if output.is_none() && tool_calls.is_empty() {
            return Err(SgrError::EmptyResponse);
        }

        Ok(SgrResponse {
            output,
            tool_calls,
            raw_text,
            usage,
            rate_limit,
        })
    }
575
576    fn extract_raw_text(&self, body: &Value) -> String {
577        let mut text = String::new();
578        if let Some(candidates) = body.get("candidates").and_then(|c| c.as_array()) {
579            for candidate in candidates {
580                if let Some(parts) = candidate
581                    .get("content")
582                    .and_then(|c| c.get("parts"))
583                    .and_then(|p| p.as_array())
584                {
585                    for part in parts {
586                        if let Some(t) = part.get("text").and_then(|t| t.as_str()) {
587                            text.push_str(t);
588                        }
589                    }
590                }
591            }
592        }
593        text
594    }
595
596    fn extract_tool_calls(&self, body: &Value) -> Vec<ToolCall> {
597        let mut calls = Vec::new();
598
599        if let Some(candidates) = body.get("candidates").and_then(|c| c.as_array()) {
600            for candidate in candidates {
601                // Standard: functionCall in parts
602                if let Some(parts) = candidate
603                    .get("content")
604                    .and_then(|c| c.get("parts"))
605                    .and_then(|p| p.as_array())
606                {
607                    let mut call_counter = 0u32;
608                    for part in parts {
609                        if let Some(fc) = part.get("functionCall") {
610                            let name = fc
611                                .get("name")
612                                .and_then(|n| n.as_str())
613                                .unwrap_or("unknown")
614                                .to_string();
615                            let args = fc.get("args").cloned().unwrap_or(json!({}));
616                            call_counter += 1;
617                            calls.push(ToolCall {
618                                id: format!("call#{}#{}", name, call_counter),
619                                name,
620                                arguments: args,
621                            });
622                        }
623                    }
624                }
625
626                // Vertex AI fallback: tool call in finishMessage when no functionDeclarations
627                // Format: "Unexpected tool call: {\"tool_name\": \"bash\", \"command\": \"...\"}"
628                if calls.is_empty() {
629                    if let Some(msg) = candidate.get("finishMessage").and_then(|m| m.as_str()) {
630                        tracing::debug!(
631                            finish_message = msg,
632                            "parsing finishMessage for tool calls"
633                        );
634                        if let Some(json_start) = msg.find('{') {
635                            let json_str = &msg[json_start..];
636                            // Try to find matching closing brace for clean extraction
637                            let json_str = if let Some(end) = json_str.rfind('}') {
638                                &json_str[..=end]
639                            } else {
640                                json_str
641                            };
642                            if let Ok(tc_json) = serde_json::from_str::<Value>(json_str) {
643                                // Handle two formats:
644                                // 1. Flat: {"tool_name": "bash", "command": "..."}
645                                // 2. Actions array: {"actions": [{"tool_name": "read_file", "path": "..."}]}
646                                let items: Vec<Value> = if let Some(actions) =
647                                    tc_json.get("actions").and_then(|a| a.as_array())
648                                {
649                                    actions.clone()
650                                } else {
651                                    vec![tc_json]
652                                };
653                                for item in items {
654                                    let name = item
655                                        .get("tool_name")
656                                        .and_then(|n| n.as_str())
657                                        .unwrap_or("unknown")
658                                        .to_string();
659                                    let mut args = item.clone();
660                                    if let Some(obj) = args.as_object_mut() {
661                                        obj.remove("tool_name");
662                                    }
663                                    calls.push(ToolCall {
664                                        id: name.clone(),
665                                        name,
666                                        arguments: args,
667                                    });
668                                }
669                            }
670                        }
671                    }
672                }
673            }
674        }
675
676        calls
677    }
678}
679
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    // Minimal structured-output target used across the tests.
    #[derive(Debug, Serialize, Deserialize, JsonSchema)]
    struct TestResponse {
        answer: String,
        confidence: f64,
    }

    // Tools present ⇒ JSON mode must be off (Gemini can't combine them).
    #[test]
    fn builds_request_with_tools_no_json_mode() {
        let client = GeminiClient::from_api_key("test-key", "gemini-2.5-flash");
        let messages = vec![Message::system("You are a helper."), Message::user("Hello")];
        let tools = vec![crate::tool::tool::<TestResponse>("test_tool", "A test")];

        let body = client
            .build_request::<TestResponse>(&messages, &tools)
            .unwrap();

        // When tools are present, no JSON mode (Gemini doesn't support both)
        assert!(body["generationConfig"]["responseSchema"].is_null());
        assert!(body["generationConfig"]["responseMimeType"].is_null());

        // Has tools + toolConfig
        assert!(body["tools"][0]["functionDeclarations"].is_array());
        assert_eq!(body["toolConfig"]["functionCallingConfig"]["mode"], "AUTO");

        // Has system instruction
        assert!(body["systemInstruction"]["parts"][0]["text"].is_string());

        // Contents only has user (system extracted)
        let contents = body["contents"].as_array().unwrap();
        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0]["role"], "user");
    }

    // No tools ⇒ structured-output JSON mode must be on.
    #[test]
    fn builds_request_without_tools_has_json_mode() {
        let client = GeminiClient::from_api_key("test-key", "gemini-2.5-flash");
        let messages = vec![Message::user("Hello")];

        let body = client
            .build_request::<TestResponse>(&messages, &[])
            .unwrap();

        // Without tools, JSON mode is enabled
        assert!(body["generationConfig"]["responseSchema"].is_object());
        assert_eq!(
            body["generationConfig"]["responseMimeType"],
            "application/json"
        );
        assert!(body["tools"].is_null());
    }

    // Text part containing JSON parses into typed output + usage totals.
    #[test]
    fn parses_text_response() {
        let client = GeminiClient::from_api_key("test", "test");
        let body = json!({
            "candidates": [{
                "content": {
                    "parts": [{
                        "text": "{\"answer\": \"42\", \"confidence\": 0.95}"
                    }]
                }
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 20,
                "totalTokenCount": 30,
            }
        });

        let result: SgrResponse<TestResponse> = client.parse_response(&body, None).unwrap();
        let output = result.output.unwrap();
        assert_eq!(output.answer, "42");
        assert_eq!(output.confidence, 0.95);
        assert!(result.tool_calls.is_empty());
        assert_eq!(result.usage.unwrap().total_tokens, 30);
    }

    // functionCall part parses into a ToolCall with a unique id.
    #[test]
    fn parses_function_call_response() {
        let client = GeminiClient::from_api_key("test", "test");
        let body = json!({
            "candidates": [{
                "content": {
                    "parts": [{
                        "functionCall": {
                            "name": "test_tool",
                            "args": {"input": "/video.mp4"}
                        }
                    }]
                }
            }]
        });

        let result: SgrResponse<TestResponse> = client.parse_response(&body, None).unwrap();
        assert!(result.output.is_none());
        assert_eq!(result.tool_calls.len(), 1);
        assert_eq!(result.tool_calls[0].name, "test_tool");
        assert_eq!(result.tool_calls[0].arguments["input"], "/video.mp4");
        // ID should be unique, not just the tool name
        assert_eq!(result.tool_calls[0].id, "call#test_tool#1");
    }

    // Repeated calls to the same tool still get distinct counter-suffixed ids.
    #[test]
    fn multiple_function_calls_get_unique_ids() {
        let client = GeminiClient::from_api_key("test", "test");
        let body = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"functionCall": {"name": "read_file", "args": {"path": "a.rs"}}},
                        {"functionCall": {"name": "read_file", "args": {"path": "b.rs"}}},
                        {"functionCall": {"name": "write_file", "args": {"path": "c.rs"}}},
                    ]
                }
            }]
        });

        let result: SgrResponse<TestResponse> = client.parse_response(&body, None).unwrap();
        assert_eq!(result.tool_calls.len(), 3);
        assert_eq!(result.tool_calls[0].id, "call#read_file#1");
        assert_eq!(result.tool_calls[1].id, "call#read_file#2");
        assert_eq!(result.tool_calls[2].id, "call#write_file#3");
        // All IDs unique
        let ids: std::collections::HashSet<_> = result.tool_calls.iter().map(|tc| &tc.id).collect();
        assert_eq!(ids.len(), 3);
    }

    // Function names are recovered from "call#<name>#<n>" ids, with a raw-id
    // fallback, and consecutive tool results merge into ONE "function" turn.
    #[test]
    fn func_name_extraction_from_call_id() {
        let client = GeminiClient::from_api_key("test", "test");

        // Build messages with tool results using our call ID format
        // Consecutive tool messages should be grouped into one "function" turn
        let messages = vec![
            Message::user("test"),
            Message::tool("call#write_file#1", "Wrote file"),
            Message::tool("call#bash#2", "Output"),
            Message::tool("call#my_custom_tool#10", "Result"),
            Message::tool("old_format_id", "Legacy"), // fallback
        ];

        let contents = client.messages_to_contents(&messages);
        // Index 0 = user, Index 1 = single function turn with 4 parts
        assert_eq!(contents.len(), 2, "consecutive tools should be grouped");
        assert_eq!(contents[1]["role"], "function");

        let parts = contents[1]["parts"].as_array().unwrap();
        assert_eq!(parts.len(), 4);
        assert_eq!(parts[0]["functionResponse"]["name"], "write_file");
        assert_eq!(parts[1]["functionResponse"]["name"], "bash");
        assert_eq!(parts[2]["functionResponse"]["name"], "my_custom_tool");
        assert_eq!(parts[3]["functionResponse"]["name"], "old_format_id");
    }

    // An intervening model turn breaks the grouping of tool results.
    #[test]
    fn tool_messages_separated_by_model_not_grouped() {
        let client = GeminiClient::from_api_key("test", "test");

        // Tool messages separated by a model message should NOT be grouped
        let messages = vec![
            Message::user("test"),
            Message::tool("call#read#1", "file A"),
            Message::assistant("thinking..."),
            Message::tool("call#read#2", "file B"),
        ];

        let contents = client.messages_to_contents(&messages);
        // user, function(1 part), model, function(1 part)
        assert_eq!(contents.len(), 4);
        assert_eq!(contents[1]["parts"].as_array().unwrap().len(), 1);
        assert_eq!(contents[3]["parts"].as_array().unwrap().len(), 1);
    }
}