Skip to main content

sgr_agent/
gemini.rs

1//! Gemini API client — structured output + function calling.
2//!
3//! Supports both Google AI Studio (API key) and Vertex AI (ADC).
4//!
5//! Two modes combined:
6//! - **Structured output**: `generationConfig.responseMimeType = "application/json"`
7//!   + `responseSchema` — forces model to return JSON matching the SGR envelope.
8//! - **Function calling**: `tools[].functionDeclarations` — model emits `functionCall`
9//!   parts that map to your Rust tool structs.
10//!
11//! The model can return BOTH structured text AND function calls in one response.
12
13use crate::schema::response_schema_for;
14use crate::tool::ToolDef;
15use crate::types::*;
16use schemars::JsonSchema;
17use serde::de::DeserializeOwned;
18use serde_json::{json, Value};
19
20/// Gemini API client.
21pub struct GeminiClient {
22    config: ProviderConfig,
23    http: reqwest::Client,
24}
25
26impl GeminiClient {
27    pub fn new(config: ProviderConfig) -> Self {
28        Self {
29            config,
30            http: reqwest::Client::new(),
31        }
32    }
33
34    /// Quick constructor for Google AI Studio (API key auth).
35    pub fn from_api_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
36        Self::new(ProviderConfig::gemini(api_key, model))
37    }
38
39    /// SGR call: structured output (typed response) + function calling (tools).
40    ///
41    /// Returns `SgrResponse<T>` where:
42    /// - `output`: parsed structured response (if model returned text)
43    /// - `tool_calls`: function calls (if model used tools)
44    ///
45    /// The model may return either or both.
46    pub async fn call<T: JsonSchema + DeserializeOwned>(
47        &self,
48        messages: &[Message],
49        tools: &[ToolDef],
50    ) -> Result<SgrResponse<T>, SgrError> {
51        let body = self.build_request::<T>(messages, tools)?;
52        let url = self.build_url();
53
54        tracing::debug!(url = %url, model = %self.config.model, "gemini_request");
55
56        let mut req = self.http.post(&url).json(&body);
57        if self.config.project_id.is_some() && !self.config.api_key.is_empty() {
58            req = req.bearer_auth(&self.config.api_key);
59        }
60        let response = req.send().await?;
61
62        let status = response.status().as_u16();
63        let headers = response.headers().clone();
64        if status != 200 {
65            let body = response.text().await.unwrap_or_default();
66            return Err(SgrError::from_response_parts(status, body, &headers));
67        }
68
69        let response_body: Value = response.json().await?;
70        let rate_limit = RateLimitInfo::from_headers(&headers);
71        self.parse_response(&response_body, rate_limit)
72    }
73
74    /// SGR call with structured output only (no tools).
75    ///
76    /// Shorthand for `call::<T>(messages, &[])`.
77    pub async fn structured<T: JsonSchema + DeserializeOwned>(
78        &self,
79        messages: &[Message],
80    ) -> Result<T, SgrError> {
81        let resp = self.call::<T>(messages, &[]).await?;
82        resp.output.ok_or(SgrError::EmptyResponse)
83    }
84
85    /// Flexible call: no structured output API, parse JSON from raw text.
86    ///
87    /// For use with text-only proxies (CLI proxy, Codex proxy) where
88    /// the model can't enforce JSON schema. Uses AnyOf cascade + coercion.
89    ///
90    /// Auto-injects JSON Schema into the system prompt so the model knows
91    /// the expected format (like BAML does).
92    pub async fn flexible<T: JsonSchema + DeserializeOwned>(
93        &self,
94        messages: &[Message],
95    ) -> Result<SgrResponse<T>, SgrError> {
96        // Send without responseSchema — plain text response
97        // Use text mode for tool messages (no functionDeclarations in this mode)
98        let contents = self.messages_to_contents_text(messages);
99        let mut system_instruction = self.extract_system(messages);
100
101        // Auto-inject schema hint into system prompt
102        let schema = response_schema_for::<T>();
103        let schema_hint = format!(
104            "\n\nRespond with valid JSON matching this schema:\n{}\n\nDo NOT wrap in markdown code blocks.",
105            serde_json::to_string_pretty(&schema).unwrap_or_default()
106        );
107        system_instruction = Some(match system_instruction {
108            Some(s) => format!("{}{}", s, schema_hint),
109            None => schema_hint,
110        });
111
112        let mut gen_config = json!({
113            "temperature": self.config.temperature,
114        });
115        if let Some(max_tokens) = self.config.max_tokens {
116            gen_config["maxOutputTokens"] = json!(max_tokens);
117        }
118
119        let mut body = json!({
120            "contents": contents,
121            "generationConfig": gen_config,
122        });
123        if let Some(system) = system_instruction {
124            body["systemInstruction"] = json!({
125                "parts": [{"text": system}]
126            });
127        }
128
129        let url = self.build_url();
130        let mut req = self.http.post(&url).json(&body);
131        if self.config.project_id.is_some() && !self.config.api_key.is_empty() {
132            req = req.bearer_auth(&self.config.api_key);
133        }
134        let response = req.send().await?;
135        let status = response.status().as_u16();
136        let headers = response.headers().clone();
137        if status != 200 {
138            let body = response.text().await.unwrap_or_default();
139            return Err(SgrError::from_response_parts(status, body, &headers));
140        }
141
142        let response_body: Value = response.json().await?;
143        let rate_limit = RateLimitInfo::from_headers(&headers);
144
145        // Extract raw text
146        let raw_text = self.extract_raw_text(&response_body);
147        if raw_text.trim().is_empty() {
148            // Log finish reason and response parts for debugging
149            if let Some(candidate) = response_body.get("candidates").and_then(|c| c.get(0)) {
150                let reason = candidate
151                    .get("finishReason")
152                    .and_then(|r| r.as_str())
153                    .unwrap_or("unknown");
154                tracing::warn!(
155                    finish_reason = reason,
156                    has_parts = candidate
157                        .get("content")
158                        .and_then(|c| c.get("parts"))
159                        .is_some(),
160                    "empty raw_text from Gemini"
161                );
162            }
163        }
164        let usage = response_body.get("usageMetadata").and_then(|u| {
165            Some(Usage {
166                prompt_tokens: u.get("promptTokenCount")?.as_u64()? as u32,
167                completion_tokens: u.get("candidatesTokenCount")?.as_u64()? as u32,
168                total_tokens: u.get("totalTokenCount")?.as_u64()? as u32,
169            })
170        });
171
172        // Extract native function calls (Gemini may use functionCall parts
173        // even without explicit functionDeclarations — especially newer models).
174        let tool_calls = self.extract_tool_calls(&response_body);
175
176        // Flexible parse with coercion.
177        // If parsing fails, return output=None with raw_text preserved
178        // so callers can implement fallback logic (e.g. wrap in finish tool).
179        let output = crate::flexible_parser::parse_flexible_coerced::<T>(&raw_text)
180            .map(|r| r.value)
181            .ok();
182
183        if output.is_none() && raw_text.trim().is_empty() && tool_calls.is_empty() {
184            // Log raw response for debugging
185            let parts_summary = response_body
186                .get("candidates")
187                .and_then(|c| c.get(0))
188                .and_then(|c| c.get("content"))
189                .and_then(|c| c.get("parts"))
190                .and_then(|p| p.as_array())
191                .map(|parts| {
192                    parts
193                        .iter()
194                        .map(|p| {
195                            if p.get("text").is_some() {
196                                "text".to_string()
197                            } else if p.get("functionCall").is_some() {
198                                format!(
199                                    "functionCall:{}",
200                                    p["functionCall"]["name"].as_str().unwrap_or("?")
201                                )
202                            } else {
203                                format!("unknown:{}", p)
204                            }
205                        })
206                        .collect::<Vec<_>>()
207                        .join(", ")
208                })
209                .unwrap_or_else(|| "no parts".into());
210            // Log full candidate for debugging
211            let candidate_json = response_body
212                .get("candidates")
213                .and_then(|c| c.get(0))
214                .map(|c| serde_json::to_string_pretty(c).unwrap_or_default())
215                .unwrap_or_else(|| "no candidates".into());
216            tracing::error!(
217                parts = parts_summary,
218                candidate = candidate_json.as_str(),
219                "SGR empty response"
220            );
221            return Err(SgrError::Schema(format!(
222                "Empty response from model (parts: {})",
223                parts_summary
224            )));
225        }
226
227        Ok(SgrResponse {
228            output,
229            tool_calls,
230            raw_text,
231            usage,
232            rate_limit,
233        })
234    }
235
236    /// Tool-only call: no structured output schema, just function calling.
237    ///
238    /// Returns raw tool calls.
239    pub async fn tools_call(
240        &self,
241        messages: &[Message],
242        tools: &[ToolDef],
243    ) -> Result<Vec<ToolCall>, SgrError> {
244        let body = self.build_tools_only_request(messages, tools)?;
245        let url = self.build_url();
246
247        let mut req = self.http.post(&url).json(&body);
248        if self.config.project_id.is_some() && !self.config.api_key.is_empty() {
249            req = req.bearer_auth(&self.config.api_key);
250        }
251        let response = req.send().await?;
252        let status = response.status().as_u16();
253        let headers = response.headers().clone();
254        if status != 200 {
255            let body = response.text().await.unwrap_or_default();
256            return Err(SgrError::from_response_parts(status, body, &headers));
257        }
258
259        let response_body: Value = response.json().await?;
260        Ok(self.extract_tool_calls(&response_body))
261    }
262
263    // --- Private ---
264
265    fn build_url(&self) -> String {
266        if let Some(project_id) = &self.config.project_id {
267            // Vertex AI
268            let location = self.config.location.as_deref().unwrap_or("global");
269            let host = if location == "global" {
270                "aiplatform.googleapis.com".to_string()
271            } else {
272                format!("{location}-aiplatform.googleapis.com")
273            };
274            format!(
275                "https://{host}/v1/projects/{project_id}/locations/{location}/publishers/google/models/{}:generateContent",
276                self.config.model
277            )
278        } else {
279            // Google AI Studio
280            format!(
281                "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
282                self.config.model, self.config.api_key
283            )
284        }
285    }
286
287    fn build_request<T: JsonSchema>(
288        &self,
289        messages: &[Message],
290        tools: &[ToolDef],
291    ) -> Result<Value, SgrError> {
292        // Use functionResponse format only when tools are present
293        let contents = if tools.is_empty() {
294            self.messages_to_contents_text(messages)
295        } else {
296            self.messages_to_contents(messages)
297        };
298        let system_instruction = self.extract_system(messages);
299
300        // When using function calling, Gemini doesn't support responseMimeType + functionDeclarations.
301        // Use structured output (JSON mode) only when there are no tools.
302        let mut gen_config = json!({
303            "temperature": self.config.temperature,
304        });
305
306        if tools.is_empty() {
307            gen_config["responseMimeType"] = json!("application/json");
308            gen_config["responseSchema"] = response_schema_for::<T>();
309        }
310
311        if let Some(max_tokens) = self.config.max_tokens {
312            gen_config["maxOutputTokens"] = json!(max_tokens);
313        }
314
315        let mut body = json!({
316            "contents": contents,
317            "generationConfig": gen_config,
318        });
319
320        if let Some(system) = system_instruction {
321            body["systemInstruction"] = json!({
322                "parts": [{"text": system}]
323            });
324        }
325
326        if !tools.is_empty() {
327            let function_declarations: Vec<Value> = tools.iter().map(|t| t.to_gemini()).collect();
328            body["tools"] = json!([{
329                "functionDeclarations": function_declarations,
330            }]);
331            body["toolConfig"] = json!({
332                "functionCallingConfig": {
333                    "mode": "AUTO"
334                }
335            });
336        }
337
338        Ok(body)
339    }
340
341    fn build_tools_only_request(
342        &self,
343        messages: &[Message],
344        tools: &[ToolDef],
345    ) -> Result<Value, SgrError> {
346        let contents = self.messages_to_contents(messages);
347        let system_instruction = self.extract_system(messages);
348
349        let mut gen_config = json!({
350            "temperature": self.config.temperature,
351        });
352        if let Some(max_tokens) = self.config.max_tokens {
353            gen_config["maxOutputTokens"] = json!(max_tokens);
354        }
355
356        let function_declarations: Vec<Value> = tools.iter().map(|t| t.to_gemini()).collect();
357
358        let mut body = json!({
359            "contents": contents,
360            "generationConfig": gen_config,
361            "tools": [{
362                "functionDeclarations": function_declarations,
363            }],
364            "toolConfig": {
365                "functionCallingConfig": {
366                    "mode": "ANY"
367                }
368            }
369        });
370
371        if let Some(system) = system_instruction {
372            body["systemInstruction"] = json!({
373                "parts": [{"text": system}]
374            });
375        }
376
377        Ok(body)
378    }
379
380    /// Convert messages to Gemini contents format.
381    ///
382    /// When `use_function_response` is true, Tool messages become `functionResponse` parts
383    /// (for native function calling mode). When false, they become user text messages
384    /// (for structured output / flexible mode where no function declarations are sent).
385    fn messages_to_contents(&self, messages: &[Message]) -> Vec<Value> {
386        self.messages_to_contents_inner(messages, true)
387    }
388
389    fn messages_to_contents_text(&self, messages: &[Message]) -> Vec<Value> {
390        self.messages_to_contents_inner(messages, false)
391    }
392
393    fn messages_to_contents_inner(
394        &self,
395        messages: &[Message],
396        use_function_response: bool,
397    ) -> Vec<Value> {
398        let mut contents = Vec::new();
399
400        let mut i = 0;
401        while i < messages.len() {
402            let msg = &messages[i];
403            match msg.role {
404                Role::System => {
405                    i += 1;
406                } // handled separately via systemInstruction
407                Role::User => {
408                    let mut parts = vec![json!({"text": msg.content})];
409                    for img in &msg.images {
410                        parts.push(json!({
411                            "inlineData": {
412                                "mimeType": img.mime_type,
413                                "data": img.data,
414                            }
415                        }));
416                    }
417                    contents.push(json!({ "role": "user", "parts": parts }));
418                    i += 1;
419                }
420                Role::Assistant => {
421                    if use_function_response && !msg.tool_calls.is_empty() {
422                        // Model turn with function calls — include functionCall parts
423                        let mut parts: Vec<Value> = Vec::new();
424                        if !msg.content.is_empty() {
425                            parts.push(json!({"text": msg.content}));
426                        }
427                        for tc in &msg.tool_calls {
428                            parts.push(json!({
429                                "functionCall": {
430                                    "name": tc.name,
431                                    "args": tc.arguments
432                                }
433                            }));
434                        }
435                        contents.push(json!({
436                            "role": "model",
437                            "parts": parts
438                        }));
439                    } else {
440                        contents.push(json!({
441                            "role": "model",
442                            "parts": [{"text": msg.content}]
443                        }));
444                    }
445                    i += 1;
446                }
447                Role::Tool => {
448                    if use_function_response {
449                        // Gemini requires ALL functionResponses for one turn in a SINGLE
450                        // "function" content entry. Collect consecutive Tool messages.
451                        let mut parts = Vec::new();
452                        let mut pending_images: Vec<(&str, &[crate::types::ImagePart])> =
453                            Vec::new();
454                        while i < messages.len() && messages[i].role == Role::Tool {
455                            let tool_msg = &messages[i];
456                            let call_id = tool_msg.tool_call_id.as_deref().unwrap_or("unknown");
457                            let func_name = match call_id.split('#').collect::<Vec<_>>().as_slice()
458                            {
459                                ["call", name, _counter] => *name,
460                                _ => call_id,
461                            };
462                            parts.push(json!({
463                                "functionResponse": {
464                                    "name": func_name,
465                                    "response": {
466                                        "content": tool_msg.content,
467                                    }
468                                }
469                            }));
470                            if !tool_msg.images.is_empty() {
471                                pending_images.push((call_id, &tool_msg.images));
472                            }
473                            i += 1;
474                        }
475                        contents.push(json!({
476                            "role": "function",
477                            "parts": parts
478                        }));
479                        // Gemini doesn't support inlineData inside functionResponse,
480                        // so attach images as a follow-up user message.
481                        for (call_id, images) in pending_images {
482                            let mut img_parts: Vec<Value> = vec![
483                                json!({"text": format!("[Images from {} tool result]", call_id)}),
484                            ];
485                            for img in images {
486                                img_parts.push(json!({
487                                    "inlineData": {
488                                        "mimeType": img.mime_type,
489                                        "data": img.data,
490                                    }
491                                }));
492                            }
493                            contents.push(json!({ "role": "user", "parts": img_parts }));
494                        }
495                    } else {
496                        // Text mode — convert tool results to user messages
497                        let call_id = msg.tool_call_id.as_deref().unwrap_or("tool");
498                        let mut parts: Vec<Value> =
499                            vec![json!({"text": format!("[{}] {}", call_id, msg.content)})];
500                        for img in &msg.images {
501                            parts.push(json!({
502                                "inlineData": {
503                                    "mimeType": img.mime_type,
504                                    "data": img.data,
505                                }
506                            }));
507                        }
508                        contents.push(json!({
509                            "role": "user",
510                            "parts": parts
511                        }));
512                        i += 1;
513                    }
514                }
515            }
516        }
517
518        contents
519    }
520
521    fn extract_system(&self, messages: &[Message]) -> Option<String> {
522        let system_parts: Vec<&str> = messages
523            .iter()
524            .filter(|m| m.role == Role::System)
525            .map(|m| m.content.as_str())
526            .collect();
527
528        if system_parts.is_empty() {
529            None
530        } else {
531            Some(system_parts.join("\n\n"))
532        }
533    }
534
535    fn parse_response<T: DeserializeOwned>(
536        &self,
537        body: &Value,
538        rate_limit: Option<RateLimitInfo>,
539    ) -> Result<SgrResponse<T>, SgrError> {
540        let mut output: Option<T> = None;
541        let mut tool_calls = Vec::new();
542        let mut raw_text = String::new();
543        let mut call_counter: u32 = 0;
544
545        // Parse usage
546        let usage = body.get("usageMetadata").and_then(|u| {
547            Some(Usage {
548                prompt_tokens: u.get("promptTokenCount")?.as_u64()? as u32,
549                completion_tokens: u.get("candidatesTokenCount")?.as_u64()? as u32,
550                total_tokens: u.get("totalTokenCount")?.as_u64()? as u32,
551            })
552        });
553
554        // Extract from candidates
555        let candidates = body
556            .get("candidates")
557            .and_then(|c| c.as_array())
558            .ok_or(SgrError::EmptyResponse)?;
559
560        for candidate in candidates {
561            let parts = candidate
562                .get("content")
563                .and_then(|c| c.get("parts"))
564                .and_then(|p| p.as_array());
565
566            if let Some(parts) = parts {
567                for part in parts {
568                    // Text part → structured output
569                    if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
570                        raw_text.push_str(text);
571                        if output.is_none() {
572                            match serde_json::from_str::<T>(text) {
573                                Ok(parsed) => output = Some(parsed),
574                                Err(e) => {
575                                    tracing::warn!(error = %e, "failed to parse structured output");
576                                }
577                            }
578                        }
579                    }
580
581                    // Function call part → tool call
582                    if let Some(fc) = part.get("functionCall") {
583                        let name = fc
584                            .get("name")
585                            .and_then(|n| n.as_str())
586                            .unwrap_or("unknown")
587                            .to_string();
588                        let args = fc.get("args").cloned().unwrap_or(json!({}));
589                        call_counter += 1;
590                        tool_calls.push(ToolCall {
591                            id: format!("call#{}#{}", name, call_counter),
592                            name,
593                            arguments: args,
594                        });
595                    }
596                }
597            }
598        }
599
600        if output.is_none() && tool_calls.is_empty() {
601            return Err(SgrError::EmptyResponse);
602        }
603
604        Ok(SgrResponse {
605            output,
606            tool_calls,
607            raw_text,
608            usage,
609            rate_limit,
610        })
611    }
612
613    fn extract_raw_text(&self, body: &Value) -> String {
614        let mut text = String::new();
615        if let Some(candidates) = body.get("candidates").and_then(|c| c.as_array()) {
616            for candidate in candidates {
617                if let Some(parts) = candidate
618                    .get("content")
619                    .and_then(|c| c.get("parts"))
620                    .and_then(|p| p.as_array())
621                {
622                    for part in parts {
623                        if let Some(t) = part.get("text").and_then(|t| t.as_str()) {
624                            text.push_str(t);
625                        }
626                    }
627                }
628            }
629        }
630        text
631    }
632
633    fn extract_tool_calls(&self, body: &Value) -> Vec<ToolCall> {
634        let mut calls = Vec::new();
635
636        if let Some(candidates) = body.get("candidates").and_then(|c| c.as_array()) {
637            for candidate in candidates {
638                // Standard: functionCall in parts
639                if let Some(parts) = candidate
640                    .get("content")
641                    .and_then(|c| c.get("parts"))
642                    .and_then(|p| p.as_array())
643                {
644                    let mut call_counter = 0u32;
645                    for part in parts {
646                        if let Some(fc) = part.get("functionCall") {
647                            let name = fc
648                                .get("name")
649                                .and_then(|n| n.as_str())
650                                .unwrap_or("unknown")
651                                .to_string();
652                            let args = fc.get("args").cloned().unwrap_or(json!({}));
653                            call_counter += 1;
654                            calls.push(ToolCall {
655                                id: format!("call#{}#{}", name, call_counter),
656                                name,
657                                arguments: args,
658                            });
659                        }
660                    }
661                }
662
663                // Vertex AI fallback: tool call in finishMessage when no functionDeclarations
664                // Format: "Unexpected tool call: {\"tool_name\": \"bash\", \"command\": \"...\"}"
665                if calls.is_empty() {
666                    if let Some(msg) = candidate.get("finishMessage").and_then(|m| m.as_str()) {
667                        tracing::debug!(
668                            finish_message = msg,
669                            "parsing finishMessage for tool calls"
670                        );
671                        if let Some(json_start) = msg.find('{') {
672                            let json_str = &msg[json_start..];
673                            // Try to find matching closing brace for clean extraction
674                            let json_str = if let Some(end) = json_str.rfind('}') {
675                                &json_str[..=end]
676                            } else {
677                                json_str
678                            };
679                            if let Ok(tc_json) = serde_json::from_str::<Value>(json_str) {
680                                // Handle two formats:
681                                // 1. Flat: {"tool_name": "bash", "command": "..."}
682                                // 2. Actions array: {"actions": [{"tool_name": "read_file", "path": "..."}]}
683                                let items: Vec<Value> = if let Some(actions) =
684                                    tc_json.get("actions").and_then(|a| a.as_array())
685                                {
686                                    actions.clone()
687                                } else {
688                                    vec![tc_json]
689                                };
690                                for item in items {
691                                    let name = item
692                                        .get("tool_name")
693                                        .and_then(|n| n.as_str())
694                                        .unwrap_or("unknown")
695                                        .to_string();
696                                    let mut args = item.clone();
697                                    if let Some(obj) = args.as_object_mut() {
698                                        obj.remove("tool_name");
699                                    }
700                                    calls.push(ToolCall {
701                                        id: name.clone(),
702                                        name,
703                                        arguments: args,
704                                    });
705                                }
706                            }
707                        }
708                    }
709                }
710            }
711        }
712
713        calls
714    }
715}
716
#[cfg(test)]
mod tests {
    // Unit tests for GeminiClient request building, response parsing and
    // message-to-contents conversion. None of them perform network I/O.

    use super::*;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use std::collections::HashSet;

    // Shared structured-output type, used both as the response schema and as a
    // tool argument schema. Plain `//` comments are used on purpose: schemars
    // turns `///` doc comments into `description` fields of the generated schema.
    #[derive(Debug, Serialize, Deserialize, JsonSchema)]
    struct TestResponse {
        answer: String,
        confidence: f64,
    }

    // Placeholder client for tests that only exercise local parsing/conversion logic.
    fn offline_client() -> GeminiClient {
        GeminiClient::from_api_key("test", "test")
    }

    #[test]
    fn builds_request_with_tools_no_json_mode() {
        let client = GeminiClient::from_api_key("test-key", "gemini-2.5-flash");
        let history = vec![Message::system("You are a helper."), Message::user("Hello")];
        let tool_defs = vec![crate::tool::tool::<TestResponse>("test_tool", "A test")];

        let req = client
            .build_request::<TestResponse>(&history, &tool_defs)
            .unwrap();

        // Tools present => JSON mode is off (Gemini doesn't support both at once).
        let gen_cfg = &req["generationConfig"];
        assert!(gen_cfg["responseSchema"].is_null());
        assert!(gen_cfg["responseMimeType"].is_null());

        // Function declarations are sent together with an AUTO calling mode.
        assert!(req["tools"][0]["functionDeclarations"].is_array());
        assert_eq!(req["toolConfig"]["functionCallingConfig"]["mode"], "AUTO");

        // The system prompt is hoisted into `systemInstruction`...
        assert!(req["systemInstruction"]["parts"][0]["text"].is_string());

        // ...so only the user turn remains in `contents`.
        let turns = req["contents"].as_array().unwrap();
        assert_eq!(turns.len(), 1);
        assert_eq!(turns[0]["role"], "user");
    }

    #[test]
    fn builds_request_without_tools_has_json_mode() {
        let client = GeminiClient::from_api_key("test-key", "gemini-2.5-flash");
        let history = vec![Message::user("Hello")];

        let req = client
            .build_request::<TestResponse>(&history, &[])
            .unwrap();

        // No tools => the response is constrained to JSON matching the schema.
        let gen_cfg = &req["generationConfig"];
        assert!(gen_cfg["responseSchema"].is_object());
        assert_eq!(gen_cfg["responseMimeType"], "application/json");
        assert!(req["tools"].is_null());
    }

    #[test]
    fn parses_text_response() {
        let client = offline_client();
        let payload = json!({
            "candidates": [{
                "content": {
                    "parts": [{ "text": "{\"answer\": \"42\", \"confidence\": 0.95}" }]
                }
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 20,
                "totalTokenCount": 30
            }
        });

        let parsed: SgrResponse<TestResponse> = client.parse_response(&payload, None).unwrap();
        let SgrResponse {
            output,
            tool_calls,
            usage,
            ..
        } = parsed;

        let structured = output.unwrap();
        assert_eq!(structured.answer, "42");
        assert_eq!(structured.confidence, 0.95);
        assert!(tool_calls.is_empty());
        assert_eq!(usage.unwrap().total_tokens, 30);
    }

    #[test]
    fn parses_function_call_response() {
        let client = offline_client();
        let payload = json!({
            "candidates": [{
                "content": {
                    "parts": [{
                        "functionCall": { "name": "test_tool", "args": { "input": "/video.mp4" } }
                    }]
                }
            }]
        });

        let parsed: SgrResponse<TestResponse> = client.parse_response(&payload, None).unwrap();
        assert!(parsed.output.is_none());
        assert_eq!(parsed.tool_calls.len(), 1);

        let call = &parsed.tool_calls[0];
        assert_eq!(call.name, "test_tool");
        assert_eq!(call.arguments["input"], "/video.mp4");
        // The ID has the form `call#<name>#<ordinal>`, not the bare tool name.
        assert_eq!(call.id, "call#test_tool#1");
    }

    #[test]
    fn multiple_function_calls_get_unique_ids() {
        let client = offline_client();
        let payload = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        { "functionCall": { "name": "read_file", "args": { "path": "a.rs" } } },
                        { "functionCall": { "name": "read_file", "args": { "path": "b.rs" } } },
                        { "functionCall": { "name": "write_file", "args": { "path": "c.rs" } } }
                    ]
                }
            }]
        });

        let parsed: SgrResponse<TestResponse> = client.parse_response(&payload, None).unwrap();
        assert_eq!(parsed.tool_calls.len(), 3);

        // The ordinal counts every call in the response, across different tool names.
        let ids: Vec<&str> = parsed.tool_calls.iter().map(|tc| tc.id.as_str()).collect();
        assert_eq!(
            ids,
            ["call#read_file#1", "call#read_file#2", "call#write_file#3"]
        );

        let distinct: HashSet<&str> = ids.iter().copied().collect();
        assert_eq!(distinct.len(), 3);
    }

    #[test]
    fn func_name_extraction_from_call_id() {
        let client = offline_client();

        // Tool results keyed by `call#<name>#<n>` IDs, plus one legacy ID without
        // that shape, which is expected to be used verbatim as the function name.
        // Consecutive tool messages should collapse into a single "function" turn.
        let history = vec![
            Message::user("test"),
            Message::tool("call#write_file#1", "Wrote file"),
            Message::tool("call#bash#2", "Output"),
            Message::tool("call#my_custom_tool#10", "Result"),
            Message::tool("old_format_id", "Legacy"),
        ];

        let contents = client.messages_to_contents(&history);
        // Expected layout: [user, function(4 parts)]
        assert_eq!(contents.len(), 2, "consecutive tools should be grouped");
        assert_eq!(contents[1]["role"], "function");

        let parts = contents[1]["parts"].as_array().unwrap();
        let expected_names = ["write_file", "bash", "my_custom_tool", "old_format_id"];
        assert_eq!(parts.len(), expected_names.len());
        for (part, expected) in parts.iter().zip(expected_names) {
            assert_eq!(part["functionResponse"]["name"], expected);
        }
    }

    #[test]
    fn tool_messages_separated_by_model_not_grouped() {
        let client = offline_client();

        // An assistant turn between two tool results must break the grouping.
        let history = vec![
            Message::user("test"),
            Message::tool("call#read#1", "file A"),
            Message::assistant("thinking..."),
            Message::tool("call#read#2", "file B"),
        ];

        let contents = client.messages_to_contents(&history);
        // Expected layout: [user, function(1 part), model, function(1 part)]
        assert_eq!(contents.len(), 4);
        for idx in [1, 3] {
            assert_eq!(contents[idx]["parts"].as_array().unwrap().len(), 1);
        }
    }
}