1use crate::schema::response_schema_for;
14use crate::tool::ToolDef;
15use crate::types::*;
16use schemars::JsonSchema;
17use serde::de::DeserializeOwned;
18use serde_json::{Value, json};
19
20pub struct GeminiClient {
22 config: ProviderConfig,
23 http: reqwest::Client,
24}
25
26impl GeminiClient {
27 pub fn new(config: ProviderConfig) -> Self {
28 Self {
29 config,
30 http: reqwest::Client::new(),
31 }
32 }
33
34 pub fn from_api_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
36 Self::new(ProviderConfig::gemini(api_key, model))
37 }
38
39 pub async fn call<T: JsonSchema + DeserializeOwned>(
47 &self,
48 messages: &[Message],
49 tools: &[ToolDef],
50 ) -> Result<SgrResponse<T>, SgrError> {
51 let body = self.build_request::<T>(messages, tools)?;
52 let url = self.build_url();
53
54 tracing::debug!(url = %url, model = %self.config.model, "gemini_request");
55
56 let mut req = self.http.post(&url).json(&body);
57 if self.config.project_id.is_some() && !self.config.api_key.is_empty() {
58 req = req.bearer_auth(&self.config.api_key);
59 }
60 let response = req.send().await?;
61
62 let status = response.status().as_u16();
63 let headers = response.headers().clone();
64 if status != 200 {
65 let body = response.text().await.unwrap_or_default();
66 return Err(SgrError::from_response_parts(status, body, &headers));
67 }
68
69 let response_body: Value = response.json().await?;
70 let rate_limit = RateLimitInfo::from_headers(&headers);
71 self.parse_response(&response_body, rate_limit)
72 }
73
74 pub async fn structured<T: JsonSchema + DeserializeOwned>(
78 &self,
79 messages: &[Message],
80 ) -> Result<T, SgrError> {
81 let resp = self.call::<T>(messages, &[]).await?;
82 resp.output.ok_or(SgrError::EmptyResponse)
83 }
84
85 pub async fn flexible<T: JsonSchema + DeserializeOwned>(
93 &self,
94 messages: &[Message],
95 ) -> Result<SgrResponse<T>, SgrError> {
96 let contents = self.messages_to_contents_text(messages);
99 let mut system_instruction = self.extract_system(messages);
100
101 let schema = response_schema_for::<T>();
103 let schema_hint = format!(
104 "\n\nRespond with valid JSON matching this schema:\n{}\n\nDo NOT wrap in markdown code blocks.",
105 serde_json::to_string_pretty(&schema).unwrap_or_default()
106 );
107 system_instruction = Some(match system_instruction {
108 Some(s) => format!("{}{}", s, schema_hint),
109 None => schema_hint,
110 });
111
112 let mut gen_config = json!({
113 "temperature": self.config.temperature,
114 });
115 if let Some(max_tokens) = self.config.max_tokens {
116 gen_config["maxOutputTokens"] = json!(max_tokens);
117 }
118
119 let mut body = json!({
120 "contents": contents,
121 "generationConfig": gen_config,
122 });
123 if let Some(system) = system_instruction {
124 body["systemInstruction"] = json!({
125 "parts": [{"text": system}]
126 });
127 }
128
129 let url = self.build_url();
130 let mut req = self.http.post(&url).json(&body);
131 if self.config.project_id.is_some() && !self.config.api_key.is_empty() {
132 req = req.bearer_auth(&self.config.api_key);
133 }
134 let response = req.send().await?;
135 let status = response.status().as_u16();
136 let headers = response.headers().clone();
137 if status != 200 {
138 let body = response.text().await.unwrap_or_default();
139 return Err(SgrError::from_response_parts(status, body, &headers));
140 }
141
142 let response_body: Value = response.json().await?;
143 let rate_limit = RateLimitInfo::from_headers(&headers);
144
145 if let Some(candidate) = response_body.get("candidates").and_then(|c| c.get(0)) {
147 let finish_reason = candidate
148 .get("finishReason")
149 .and_then(|r| r.as_str())
150 .unwrap_or("");
151 if finish_reason == "MAX_TOKENS" {
152 let partial = self.extract_raw_text(&response_body);
153 return Err(SgrError::MaxOutputTokens {
154 partial_content: partial,
155 });
156 }
157 }
158
159 let raw_text = self.extract_raw_text(&response_body);
161 if raw_text.trim().is_empty() {
162 if let Some(candidate) = response_body.get("candidates").and_then(|c| c.get(0)) {
164 let reason = candidate
165 .get("finishReason")
166 .and_then(|r| r.as_str())
167 .unwrap_or("unknown");
168 tracing::warn!(
169 finish_reason = reason,
170 has_parts = candidate
171 .get("content")
172 .and_then(|c| c.get("parts"))
173 .is_some(),
174 "empty raw_text from Gemini"
175 );
176 }
177 }
178 let usage = response_body.get("usageMetadata").and_then(|u| {
179 Some(Usage {
180 prompt_tokens: u.get("promptTokenCount")?.as_u64()? as u32,
181 completion_tokens: u.get("candidatesTokenCount")?.as_u64()? as u32,
182 total_tokens: u.get("totalTokenCount")?.as_u64()? as u32,
183 })
184 });
185
186 let tool_calls = self.extract_tool_calls(&response_body);
189
190 let output = crate::flexible_parser::parse_flexible_coerced::<T>(&raw_text)
194 .map(|r| r.value)
195 .ok();
196
197 if output.is_none() && raw_text.trim().is_empty() && tool_calls.is_empty() {
198 let parts_summary = response_body
200 .get("candidates")
201 .and_then(|c| c.get(0))
202 .and_then(|c| c.get("content"))
203 .and_then(|c| c.get("parts"))
204 .and_then(|p| p.as_array())
205 .map(|parts| {
206 parts
207 .iter()
208 .map(|p| {
209 if p.get("text").is_some() {
210 "text".to_string()
211 } else if p.get("functionCall").is_some() {
212 format!(
213 "functionCall:{}",
214 p["functionCall"]["name"].as_str().unwrap_or("?")
215 )
216 } else {
217 format!("unknown:{}", p)
218 }
219 })
220 .collect::<Vec<_>>()
221 .join(", ")
222 })
223 .unwrap_or_else(|| "no parts".into());
224 let candidate_json = response_body
226 .get("candidates")
227 .and_then(|c| c.get(0))
228 .map(|c| serde_json::to_string_pretty(c).unwrap_or_default())
229 .unwrap_or_else(|| "no candidates".into());
230 tracing::error!(
231 parts = parts_summary,
232 candidate = candidate_json.as_str(),
233 "SGR empty response"
234 );
235 return Err(SgrError::Schema(format!(
236 "Empty response from model (parts: {})",
237 parts_summary
238 )));
239 }
240
241 Ok(SgrResponse {
242 output,
243 tool_calls,
244 raw_text,
245 usage,
246 rate_limit,
247 })
248 }
249
250 pub async fn tools_call(
254 &self,
255 messages: &[Message],
256 tools: &[ToolDef],
257 ) -> Result<Vec<ToolCall>, SgrError> {
258 let body = self.build_tools_only_request(messages, tools)?;
259 let url = self.build_url();
260
261 let mut req = self.http.post(&url).json(&body);
262 if self.config.project_id.is_some() && !self.config.api_key.is_empty() {
263 req = req.bearer_auth(&self.config.api_key);
264 }
265 let response = req.send().await?;
266 let status = response.status().as_u16();
267 let headers = response.headers().clone();
268 if status != 200 {
269 let body = response.text().await.unwrap_or_default();
270 return Err(SgrError::from_response_parts(status, body, &headers));
271 }
272
273 let response_body: Value = response.json().await?;
274
275 if let Some(candidate) = response_body.get("candidates").and_then(|c| c.get(0)) {
277 let finish_reason = candidate
278 .get("finishReason")
279 .and_then(|r| r.as_str())
280 .unwrap_or("");
281 if finish_reason == "MAX_TOKENS" {
282 let partial = self.extract_raw_text(&response_body);
283 return Err(SgrError::MaxOutputTokens {
284 partial_content: partial,
285 });
286 }
287 }
288
289 Ok(self.extract_tool_calls(&response_body))
290 }
291
292 fn build_url(&self) -> String {
295 if let Some(project_id) = &self.config.project_id {
296 let location = self.config.location.as_deref().unwrap_or("global");
298 let host = if location == "global" {
299 "aiplatform.googleapis.com".to_string()
300 } else {
301 format!("{location}-aiplatform.googleapis.com")
302 };
303 format!(
304 "https://{host}/v1/projects/{project_id}/locations/{location}/publishers/google/models/{}:generateContent",
305 self.config.model
306 )
307 } else {
308 format!(
310 "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
311 self.config.model, self.config.api_key
312 )
313 }
314 }
315
316 fn build_request<T: JsonSchema>(
317 &self,
318 messages: &[Message],
319 tools: &[ToolDef],
320 ) -> Result<Value, SgrError> {
321 let contents = if tools.is_empty() {
323 self.messages_to_contents_text(messages)
324 } else {
325 self.messages_to_contents(messages)
326 };
327 let system_instruction = self.extract_system(messages);
328
329 let mut gen_config = json!({
332 "temperature": self.config.temperature,
333 });
334
335 if tools.is_empty() {
336 gen_config["responseMimeType"] = json!("application/json");
337 gen_config["responseSchema"] = response_schema_for::<T>();
338 }
339
340 if let Some(max_tokens) = self.config.max_tokens {
341 gen_config["maxOutputTokens"] = json!(max_tokens);
342 }
343
344 let mut body = json!({
345 "contents": contents,
346 "generationConfig": gen_config,
347 });
348
349 if let Some(system) = system_instruction {
350 body["systemInstruction"] = json!({
351 "parts": [{"text": system}]
352 });
353 }
354
355 if !tools.is_empty() {
356 let function_declarations: Vec<Value> = tools.iter().map(|t| t.to_gemini()).collect();
357 body["tools"] = json!([{
358 "functionDeclarations": function_declarations,
359 }]);
360 body["toolConfig"] = json!({
361 "functionCallingConfig": {
362 "mode": "AUTO"
363 }
364 });
365 }
366
367 Ok(body)
368 }
369
370 fn build_tools_only_request(
371 &self,
372 messages: &[Message],
373 tools: &[ToolDef],
374 ) -> Result<Value, SgrError> {
375 let contents = self.messages_to_contents(messages);
376 let system_instruction = self.extract_system(messages);
377
378 let mut gen_config = json!({
379 "temperature": self.config.temperature,
380 });
381 if let Some(max_tokens) = self.config.max_tokens {
382 gen_config["maxOutputTokens"] = json!(max_tokens);
383 }
384
385 let function_declarations: Vec<Value> = tools.iter().map(|t| t.to_gemini()).collect();
386
387 let mut body = json!({
388 "contents": contents,
389 "generationConfig": gen_config,
390 "tools": [{
391 "functionDeclarations": function_declarations,
392 }],
393 "toolConfig": {
394 "functionCallingConfig": {
395 "mode": "ANY"
396 }
397 }
398 });
399
400 if let Some(system) = system_instruction {
401 body["systemInstruction"] = json!({
402 "parts": [{"text": system}]
403 });
404 }
405
406 Ok(body)
407 }
408
409 fn messages_to_contents(&self, messages: &[Message]) -> Vec<Value> {
415 self.messages_to_contents_inner(messages, true)
416 }
417
418 fn messages_to_contents_text(&self, messages: &[Message]) -> Vec<Value> {
419 self.messages_to_contents_inner(messages, false)
420 }
421
422 fn messages_to_contents_inner(
423 &self,
424 messages: &[Message],
425 use_function_response: bool,
426 ) -> Vec<Value> {
427 let mut contents = Vec::new();
428
429 let mut i = 0;
430 while i < messages.len() {
431 let msg = &messages[i];
432 match msg.role {
433 Role::System => {
434 i += 1;
435 } Role::User => {
437 let mut parts = vec![json!({"text": msg.content})];
438 for img in &msg.images {
439 parts.push(json!({
440 "inlineData": {
441 "mimeType": img.mime_type,
442 "data": img.data,
443 }
444 }));
445 }
446 contents.push(json!({ "role": "user", "parts": parts }));
447 i += 1;
448 }
449 Role::Assistant => {
450 if use_function_response && !msg.tool_calls.is_empty() {
451 let mut parts: Vec<Value> = Vec::new();
453 if !msg.content.is_empty() {
454 parts.push(json!({"text": msg.content}));
455 }
456 for tc in &msg.tool_calls {
457 parts.push(json!({
458 "functionCall": {
459 "name": tc.name,
460 "args": tc.arguments
461 }
462 }));
463 }
464 contents.push(json!({
465 "role": "model",
466 "parts": parts
467 }));
468 } else {
469 contents.push(json!({
470 "role": "model",
471 "parts": [{"text": msg.content}]
472 }));
473 }
474 i += 1;
475 }
476 Role::Tool => {
477 if use_function_response {
478 let mut parts = Vec::new();
481 let mut pending_images: Vec<(&str, &[crate::types::ImagePart])> =
482 Vec::new();
483 while i < messages.len() && messages[i].role == Role::Tool {
484 let tool_msg = &messages[i];
485 let call_id = tool_msg.tool_call_id.as_deref().unwrap_or("unknown");
486 let func_name = match call_id.split('#').collect::<Vec<_>>().as_slice()
487 {
488 ["call", name, _counter] => *name,
489 _ => call_id,
490 };
491 parts.push(json!({
492 "functionResponse": {
493 "name": func_name,
494 "response": {
495 "content": tool_msg.content,
496 }
497 }
498 }));
499 if !tool_msg.images.is_empty() {
500 pending_images.push((call_id, &tool_msg.images));
501 }
502 i += 1;
503 }
504 contents.push(json!({
505 "role": "function",
506 "parts": parts
507 }));
508 for (call_id, images) in pending_images {
511 let mut img_parts: Vec<Value> = vec![
512 json!({"text": format!("[Images from {} tool result]", call_id)}),
513 ];
514 for img in images {
515 img_parts.push(json!({
516 "inlineData": {
517 "mimeType": img.mime_type,
518 "data": img.data,
519 }
520 }));
521 }
522 contents.push(json!({ "role": "user", "parts": img_parts }));
523 }
524 } else {
525 let call_id = msg.tool_call_id.as_deref().unwrap_or("tool");
527 let mut parts: Vec<Value> =
528 vec![json!({"text": format!("[{}] {}", call_id, msg.content)})];
529 for img in &msg.images {
530 parts.push(json!({
531 "inlineData": {
532 "mimeType": img.mime_type,
533 "data": img.data,
534 }
535 }));
536 }
537 contents.push(json!({
538 "role": "user",
539 "parts": parts
540 }));
541 i += 1;
542 }
543 }
544 }
545 }
546
547 contents
548 }
549
550 fn extract_system(&self, messages: &[Message]) -> Option<String> {
551 let system_parts: Vec<&str> = messages
552 .iter()
553 .filter(|m| m.role == Role::System)
554 .map(|m| m.content.as_str())
555 .collect();
556
557 if system_parts.is_empty() {
558 None
559 } else {
560 Some(system_parts.join("\n\n"))
561 }
562 }
563
564 fn parse_response<T: DeserializeOwned>(
565 &self,
566 body: &Value,
567 rate_limit: Option<RateLimitInfo>,
568 ) -> Result<SgrResponse<T>, SgrError> {
569 let mut output: Option<T> = None;
570 let mut tool_calls = Vec::new();
571 let mut raw_text = String::new();
572 let mut call_counter: u32 = 0;
573
574 let usage = body.get("usageMetadata").and_then(|u| {
576 Some(Usage {
577 prompt_tokens: u.get("promptTokenCount")?.as_u64()? as u32,
578 completion_tokens: u.get("candidatesTokenCount")?.as_u64()? as u32,
579 total_tokens: u.get("totalTokenCount")?.as_u64()? as u32,
580 })
581 });
582
583 let candidates = body
585 .get("candidates")
586 .and_then(|c| c.as_array())
587 .ok_or(SgrError::EmptyResponse)?;
588
589 for candidate in candidates {
590 let parts = candidate
591 .get("content")
592 .and_then(|c| c.get("parts"))
593 .and_then(|p| p.as_array());
594
595 if let Some(parts) = parts {
596 for part in parts {
597 if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
599 raw_text.push_str(text);
600 if output.is_none() {
601 match serde_json::from_str::<T>(text) {
602 Ok(parsed) => output = Some(parsed),
603 Err(e) => {
604 tracing::warn!(error = %e, "failed to parse structured output");
605 }
606 }
607 }
608 }
609
610 if let Some(fc) = part.get("functionCall") {
612 let name = fc
613 .get("name")
614 .and_then(|n| n.as_str())
615 .unwrap_or("unknown")
616 .to_string();
617 let args = fc.get("args").cloned().unwrap_or(json!({}));
618 call_counter += 1;
619 tool_calls.push(ToolCall {
620 id: format!("call#{}#{}", name, call_counter),
621 name,
622 arguments: args,
623 });
624 }
625 }
626 }
627 }
628
629 if output.is_none() && tool_calls.is_empty() {
630 return Err(SgrError::EmptyResponse);
631 }
632
633 Ok(SgrResponse {
634 output,
635 tool_calls,
636 raw_text,
637 usage,
638 rate_limit,
639 })
640 }
641
642 fn extract_raw_text(&self, body: &Value) -> String {
643 let mut text = String::new();
644 if let Some(candidates) = body.get("candidates").and_then(|c| c.as_array()) {
645 for candidate in candidates {
646 if let Some(parts) = candidate
647 .get("content")
648 .and_then(|c| c.get("parts"))
649 .and_then(|p| p.as_array())
650 {
651 for part in parts {
652 if let Some(t) = part.get("text").and_then(|t| t.as_str()) {
653 text.push_str(t);
654 }
655 }
656 }
657 }
658 }
659 text
660 }
661
662 fn extract_tool_calls(&self, body: &Value) -> Vec<ToolCall> {
663 let mut calls = Vec::new();
664
665 if let Some(candidates) = body.get("candidates").and_then(|c| c.as_array()) {
666 for candidate in candidates {
667 if let Some(parts) = candidate
669 .get("content")
670 .and_then(|c| c.get("parts"))
671 .and_then(|p| p.as_array())
672 {
673 let mut call_counter = 0u32;
674 for part in parts {
675 if let Some(fc) = part.get("functionCall") {
676 let name = fc
677 .get("name")
678 .and_then(|n| n.as_str())
679 .unwrap_or("unknown")
680 .to_string();
681 let args = fc.get("args").cloned().unwrap_or(json!({}));
682 call_counter += 1;
683 calls.push(ToolCall {
684 id: format!("call#{}#{}", name, call_counter),
685 name,
686 arguments: args,
687 });
688 }
689 }
690 }
691
692 if calls.is_empty()
695 && let Some(msg) = candidate.get("finishMessage").and_then(|m| m.as_str())
696 {
697 tracing::debug!(finish_message = msg, "parsing finishMessage for tool calls");
698 if let Some(json_start) = msg.find('{') {
699 let json_str = &msg[json_start..];
700 let json_str = if let Some(end) = json_str.rfind('}') {
702 &json_str[..=end]
703 } else {
704 json_str
705 };
706 if let Ok(tc_json) = serde_json::from_str::<Value>(json_str) {
707 let items: Vec<Value> = if let Some(actions) =
711 tc_json.get("actions").and_then(|a| a.as_array())
712 {
713 actions.clone()
714 } else {
715 vec![tc_json]
716 };
717 for item in items {
718 let name = item
719 .get("tool_name")
720 .and_then(|n| n.as_str())
721 .unwrap_or("unknown")
722 .to_string();
723 let mut args = item.clone();
724 if let Some(obj) = args.as_object_mut() {
725 obj.remove("tool_name");
726 }
727 calls.push(ToolCall {
728 id: name.clone(),
729 name,
730 arguments: args,
731 });
732 }
733 }
734 }
735 }
736 }
737 }
738
739 calls
740 }
741}
742
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Serialize, Deserialize, JsonSchema)]
    struct TestResponse {
        answer: String,
        confidence: f64,
    }

    #[test]
    fn builds_request_with_tools_no_json_mode() {
        let client = GeminiClient::from_api_key("test-key", "gemini-2.5-flash");
        let messages = vec![Message::system("You are a helper."), Message::user("Hello")];
        let tools = vec![crate::tool::tool::<TestResponse>("test_tool", "A test")];

        let body = client
            .build_request::<TestResponse>(&messages, &tools)
            .unwrap();

        // Tools disable JSON mode: no schema and no MIME type.
        assert!(body["generationConfig"]["responseSchema"].is_null());
        assert!(body["generationConfig"]["responseMimeType"].is_null());

        assert!(body["tools"][0]["functionDeclarations"].is_array());
        assert_eq!(body["toolConfig"]["functionCallingConfig"]["mode"], "AUTO");

        assert!(body["systemInstruction"]["parts"][0]["text"].is_string());

        // The system message moves to systemInstruction; only the user turn remains.
        let contents = body["contents"].as_array().unwrap();
        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0]["role"], "user");
    }

    #[test]
    fn builds_request_without_tools_has_json_mode() {
        let client = GeminiClient::from_api_key("test-key", "gemini-2.5-flash");
        let messages = vec![Message::user("Hello")];

        let body = client
            .build_request::<TestResponse>(&messages, &[])
            .unwrap();

        assert!(body["generationConfig"]["responseSchema"].is_object());
        assert_eq!(
            body["generationConfig"]["responseMimeType"],
            "application/json"
        );
        assert!(body["tools"].is_null());
    }

    #[test]
    fn tools_only_request_forces_function_calling() {
        let client = GeminiClient::from_api_key("test-key", "gemini-2.5-flash");
        let tools = vec![crate::tool::tool::<TestResponse>("test_tool", "A test")];

        let body = client
            .build_tools_only_request(&[Message::user("Hello")], &tools)
            .unwrap();

        assert_eq!(body["toolConfig"]["functionCallingConfig"]["mode"], "ANY");
        assert!(body["tools"][0]["functionDeclarations"].is_array());
        assert!(body["generationConfig"]["responseSchema"].is_null());
        assert!(body["systemInstruction"].is_null());
    }

    #[test]
    fn parses_text_response() {
        let client = GeminiClient::from_api_key("test", "test");
        let body = json!({
            "candidates": [{
                "content": {
                    "parts": [{
                        "text": "{\"answer\": \"42\", \"confidence\": 0.95}"
                    }]
                }
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 20,
                "totalTokenCount": 30,
            }
        });

        let result: SgrResponse<TestResponse> = client.parse_response(&body, None).unwrap();
        let output = result.output.unwrap();
        assert_eq!(output.answer, "42");
        assert_eq!(output.confidence, 0.95);
        assert!(result.tool_calls.is_empty());
        assert_eq!(result.usage.unwrap().total_tokens, 30);
    }

    #[test]
    fn parses_function_call_response() {
        let client = GeminiClient::from_api_key("test", "test");
        let body = json!({
            "candidates": [{
                "content": {
                    "parts": [{
                        "functionCall": {
                            "name": "test_tool",
                            "args": {"input": "/video.mp4"}
                        }
                    }]
                }
            }]
        });

        let result: SgrResponse<TestResponse> = client.parse_response(&body, None).unwrap();
        assert!(result.output.is_none());
        assert_eq!(result.tool_calls.len(), 1);
        assert_eq!(result.tool_calls[0].name, "test_tool");
        assert_eq!(result.tool_calls[0].arguments["input"], "/video.mp4");
        assert_eq!(result.tool_calls[0].id, "call#test_tool#1");
    }

    #[test]
    fn multiple_function_calls_get_unique_ids() {
        let client = GeminiClient::from_api_key("test", "test");
        let body = json!({
            "candidates": [{
                "content": {
                    "parts": [
                        {"functionCall": {"name": "read_file", "args": {"path": "a.rs"}}},
                        {"functionCall": {"name": "read_file", "args": {"path": "b.rs"}}},
                        {"functionCall": {"name": "write_file", "args": {"path": "c.rs"}}},
                    ]
                }
            }]
        });

        let result: SgrResponse<TestResponse> = client.parse_response(&body, None).unwrap();
        assert_eq!(result.tool_calls.len(), 3);
        assert_eq!(result.tool_calls[0].id, "call#read_file#1");
        assert_eq!(result.tool_calls[1].id, "call#read_file#2");
        assert_eq!(result.tool_calls[2].id, "call#write_file#3");
        let ids: std::collections::HashSet<_> = result.tool_calls.iter().map(|tc| &tc.id).collect();
        assert_eq!(ids.len(), 3);
    }

    #[test]
    fn missing_candidates_is_empty_response() {
        let client = GeminiClient::from_api_key("test", "test");
        let result = client.parse_response::<TestResponse>(&json!({}), None);
        assert!(matches!(result, Err(SgrError::EmptyResponse)));
    }

    #[test]
    fn extracts_tool_calls_from_finish_message() {
        let client = GeminiClient::from_api_key("test", "test");
        let body = json!({
            "candidates": [{
                "finishReason": "MALFORMED_FUNCTION_CALL",
                "finishMessage": "Malformed function call: print({\"actions\": [{\"tool_name\": \"bash\", \"cmd\": \"ls\"}]})"
            }]
        });

        let calls = client.extract_tool_calls(&body);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "bash");
        assert_eq!(calls[0].arguments["cmd"], "ls");
        assert!(calls[0].arguments.get("tool_name").is_none());
    }

    #[test]
    fn func_name_extraction_from_call_id() {
        let client = GeminiClient::from_api_key("test", "test");

        let messages = vec![
            Message::user("test"),
            Message::tool("call#write_file#1", "Wrote file"),
            Message::tool("call#bash#2", "Output"),
            Message::tool("call#my_custom_tool#10", "Result"),
            Message::tool("old_format_id", "Legacy"),
        ];

        let contents = client.messages_to_contents(&messages);
        assert_eq!(contents.len(), 2, "consecutive tools should be grouped");
        assert_eq!(contents[1]["role"], "function");

        let parts = contents[1]["parts"].as_array().unwrap();
        assert_eq!(parts.len(), 4);
        assert_eq!(parts[0]["functionResponse"]["name"], "write_file");
        assert_eq!(parts[1]["functionResponse"]["name"], "bash");
        assert_eq!(parts[2]["functionResponse"]["name"], "my_custom_tool");
        assert_eq!(parts[3]["functionResponse"]["name"], "old_format_id");
    }

    #[test]
    fn tool_messages_separated_by_model_not_grouped() {
        let client = GeminiClient::from_api_key("test", "test");

        let messages = vec![
            Message::user("test"),
            Message::tool("call#read#1", "file A"),
            Message::assistant("thinking..."),
            Message::tool("call#read#2", "file B"),
        ];

        let contents = client.messages_to_contents(&messages);
        assert_eq!(contents.len(), 4);
        assert_eq!(contents[1]["parts"].as_array().unwrap().len(), 1);
        assert_eq!(contents[3]["parts"].as_array().unwrap().len(), 1);
    }

    #[test]
    fn text_mode_inlines_tool_results_as_user_text() {
        let client = GeminiClient::from_api_key("test", "test");

        let messages = vec![
            Message::system("sys"),
            Message::user("hi"),
            Message::tool("call#bash#1", "ok"),
        ];

        let contents = client.messages_to_contents_text(&messages);
        assert_eq!(contents.len(), 2);
        assert_eq!(contents[1]["role"], "user");
        assert_eq!(contents[1]["parts"][0]["text"], "[call#bash#1] ok");
    }

    #[test]
    fn system_messages_are_joined() {
        let client = GeminiClient::from_api_key("test", "test");

        let messages = vec![
            Message::system("a"),
            Message::user("u"),
            Message::system("b"),
        ];
        assert_eq!(client.extract_system(&messages).as_deref(), Some("a\n\nb"));
        assert!(client.extract_system(&[Message::user("u")]).is_none());
    }
}