1use crate::schema::response_schema_for;
14use crate::tool::ToolDef;
15use crate::types::*;
16use schemars::JsonSchema;
17use serde::de::DeserializeOwned;
18use serde_json::{json, Value};
19
/// Client for Google's Gemini `generateContent` endpoint.
///
/// When `ProviderConfig::project_id` is set, requests target the Vertex AI
/// endpoint; otherwise they go to the public Generative Language API
/// (see `build_url`).
pub struct GeminiClient {
    /// Provider settings: API key, model name, temperature, optional token limit
    /// and optional Vertex AI project/location.
    config: ProviderConfig,
    /// HTTP client reused for every request issued by this instance.
    http: reqwest::Client,
}
25
26impl GeminiClient {
27 pub fn new(config: ProviderConfig) -> Self {
28 Self {
29 config,
30 http: reqwest::Client::new(),
31 }
32 }
33
34 pub fn from_api_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
36 Self::new(ProviderConfig::gemini(api_key, model))
37 }
38
39 pub async fn call<T: JsonSchema + DeserializeOwned>(
47 &self,
48 messages: &[Message],
49 tools: &[ToolDef],
50 ) -> Result<SgrResponse<T>, SgrError> {
51 let body = self.build_request::<T>(messages, tools)?;
52 let url = self.build_url();
53
54 tracing::debug!(url = %url, model = %self.config.model, "gemini_request");
55
56 let mut req = self.http.post(&url).json(&body);
57 if self.config.project_id.is_some() && !self.config.api_key.is_empty() {
58 req = req.bearer_auth(&self.config.api_key);
59 }
60 let response = req.send().await?;
61
62 let status = response.status().as_u16();
63 let headers = response.headers().clone();
64 if status != 200 {
65 let body = response.text().await.unwrap_or_default();
66 return Err(SgrError::from_response_parts(status, body, &headers));
67 }
68
69 let response_body: Value = response.json().await?;
70 let rate_limit = RateLimitInfo::from_headers(&headers);
71 self.parse_response(&response_body, rate_limit)
72 }
73
74 pub async fn structured<T: JsonSchema + DeserializeOwned>(
78 &self,
79 messages: &[Message],
80 ) -> Result<T, SgrError> {
81 let resp = self.call::<T>(messages, &[]).await?;
82 resp.output.ok_or(SgrError::EmptyResponse)
83 }
84
85 pub async fn flexible<T: JsonSchema + DeserializeOwned>(
93 &self,
94 messages: &[Message],
95 ) -> Result<SgrResponse<T>, SgrError> {
96 let contents = self.messages_to_contents_text(messages);
99 let mut system_instruction = self.extract_system(messages);
100
101 let schema = response_schema_for::<T>();
103 let schema_hint = format!(
104 "\n\nRespond with valid JSON matching this schema:\n{}\n\nDo NOT wrap in markdown code blocks.",
105 serde_json::to_string_pretty(&schema).unwrap_or_default()
106 );
107 system_instruction = Some(match system_instruction {
108 Some(s) => format!("{}{}", s, schema_hint),
109 None => schema_hint,
110 });
111
112 let mut gen_config = json!({
113 "temperature": self.config.temperature,
114 });
115 if let Some(max_tokens) = self.config.max_tokens {
116 gen_config["maxOutputTokens"] = json!(max_tokens);
117 }
118
119 let mut body = json!({
120 "contents": contents,
121 "generationConfig": gen_config,
122 });
123 if let Some(system) = system_instruction {
124 body["systemInstruction"] = json!({
125 "parts": [{"text": system}]
126 });
127 }
128
129 let url = self.build_url();
130 let mut req = self.http.post(&url).json(&body);
131 if self.config.project_id.is_some() && !self.config.api_key.is_empty() {
132 req = req.bearer_auth(&self.config.api_key);
133 }
134 let response = req.send().await?;
135 let status = response.status().as_u16();
136 let headers = response.headers().clone();
137 if status != 200 {
138 let body = response.text().await.unwrap_or_default();
139 return Err(SgrError::from_response_parts(status, body, &headers));
140 }
141
142 let response_body: Value = response.json().await?;
143 let rate_limit = RateLimitInfo::from_headers(&headers);
144
145 let raw_text = self.extract_raw_text(&response_body);
147 if raw_text.trim().is_empty() {
148 if let Some(candidate) = response_body.get("candidates").and_then(|c| c.get(0)) {
150 let reason = candidate
151 .get("finishReason")
152 .and_then(|r| r.as_str())
153 .unwrap_or("unknown");
154 tracing::warn!(
155 finish_reason = reason,
156 has_parts = candidate
157 .get("content")
158 .and_then(|c| c.get("parts"))
159 .is_some(),
160 "empty raw_text from Gemini"
161 );
162 }
163 }
164 let usage = response_body.get("usageMetadata").and_then(|u| {
165 Some(Usage {
166 prompt_tokens: u.get("promptTokenCount")?.as_u64()? as u32,
167 completion_tokens: u.get("candidatesTokenCount")?.as_u64()? as u32,
168 total_tokens: u.get("totalTokenCount")?.as_u64()? as u32,
169 })
170 });
171
172 let tool_calls = self.extract_tool_calls(&response_body);
175
176 let output = crate::flexible_parser::parse_flexible_coerced::<T>(&raw_text)
180 .map(|r| r.value)
181 .ok();
182
183 if output.is_none() && raw_text.trim().is_empty() && tool_calls.is_empty() {
184 let parts_summary = response_body
186 .get("candidates")
187 .and_then(|c| c.get(0))
188 .and_then(|c| c.get("content"))
189 .and_then(|c| c.get("parts"))
190 .and_then(|p| p.as_array())
191 .map(|parts| {
192 parts
193 .iter()
194 .map(|p| {
195 if p.get("text").is_some() {
196 "text".to_string()
197 } else if p.get("functionCall").is_some() {
198 format!(
199 "functionCall:{}",
200 p["functionCall"]["name"].as_str().unwrap_or("?")
201 )
202 } else {
203 format!("unknown:{}", p)
204 }
205 })
206 .collect::<Vec<_>>()
207 .join(", ")
208 })
209 .unwrap_or_else(|| "no parts".into());
210 let candidate_json = response_body
212 .get("candidates")
213 .and_then(|c| c.get(0))
214 .map(|c| serde_json::to_string_pretty(c).unwrap_or_default())
215 .unwrap_or_else(|| "no candidates".into());
216 tracing::error!(
217 parts = parts_summary,
218 candidate = candidate_json.as_str(),
219 "SGR empty response"
220 );
221 return Err(SgrError::Schema(format!(
222 "Empty response from model (parts: {})",
223 parts_summary
224 )));
225 }
226
227 Ok(SgrResponse {
228 output,
229 tool_calls,
230 raw_text,
231 usage,
232 rate_limit,
233 })
234 }
235
236 pub async fn tools_call(
240 &self,
241 messages: &[Message],
242 tools: &[ToolDef],
243 ) -> Result<Vec<ToolCall>, SgrError> {
244 let body = self.build_tools_only_request(messages, tools)?;
245 let url = self.build_url();
246
247 let mut req = self.http.post(&url).json(&body);
248 if self.config.project_id.is_some() && !self.config.api_key.is_empty() {
249 req = req.bearer_auth(&self.config.api_key);
250 }
251 let response = req.send().await?;
252 let status = response.status().as_u16();
253 let headers = response.headers().clone();
254 if status != 200 {
255 let body = response.text().await.unwrap_or_default();
256 return Err(SgrError::from_response_parts(status, body, &headers));
257 }
258
259 let response_body: Value = response.json().await?;
260 Ok(self.extract_tool_calls(&response_body))
261 }
262
263 fn build_url(&self) -> String {
266 if let Some(project_id) = &self.config.project_id {
267 let location = self.config.location.as_deref().unwrap_or("global");
269 let host = if location == "global" {
270 "aiplatform.googleapis.com".to_string()
271 } else {
272 format!("{location}-aiplatform.googleapis.com")
273 };
274 format!(
275 "https://{host}/v1/projects/{project_id}/locations/{location}/publishers/google/models/{}:generateContent",
276 self.config.model
277 )
278 } else {
279 format!(
281 "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
282 self.config.model, self.config.api_key
283 )
284 }
285 }
286
287 fn build_request<T: JsonSchema>(
288 &self,
289 messages: &[Message],
290 tools: &[ToolDef],
291 ) -> Result<Value, SgrError> {
292 let contents = if tools.is_empty() {
294 self.messages_to_contents_text(messages)
295 } else {
296 self.messages_to_contents(messages)
297 };
298 let system_instruction = self.extract_system(messages);
299
300 let mut gen_config = json!({
303 "temperature": self.config.temperature,
304 });
305
306 if tools.is_empty() {
307 gen_config["responseMimeType"] = json!("application/json");
308 gen_config["responseSchema"] = response_schema_for::<T>();
309 }
310
311 if let Some(max_tokens) = self.config.max_tokens {
312 gen_config["maxOutputTokens"] = json!(max_tokens);
313 }
314
315 let mut body = json!({
316 "contents": contents,
317 "generationConfig": gen_config,
318 });
319
320 if let Some(system) = system_instruction {
321 body["systemInstruction"] = json!({
322 "parts": [{"text": system}]
323 });
324 }
325
326 if !tools.is_empty() {
327 let function_declarations: Vec<Value> = tools.iter().map(|t| t.to_gemini()).collect();
328 body["tools"] = json!([{
329 "functionDeclarations": function_declarations,
330 }]);
331 body["toolConfig"] = json!({
332 "functionCallingConfig": {
333 "mode": "AUTO"
334 }
335 });
336 }
337
338 Ok(body)
339 }
340
341 fn build_tools_only_request(
342 &self,
343 messages: &[Message],
344 tools: &[ToolDef],
345 ) -> Result<Value, SgrError> {
346 let contents = self.messages_to_contents(messages);
347 let system_instruction = self.extract_system(messages);
348
349 let mut gen_config = json!({
350 "temperature": self.config.temperature,
351 });
352 if let Some(max_tokens) = self.config.max_tokens {
353 gen_config["maxOutputTokens"] = json!(max_tokens);
354 }
355
356 let function_declarations: Vec<Value> = tools.iter().map(|t| t.to_gemini()).collect();
357
358 let mut body = json!({
359 "contents": contents,
360 "generationConfig": gen_config,
361 "tools": [{
362 "functionDeclarations": function_declarations,
363 }],
364 "toolConfig": {
365 "functionCallingConfig": {
366 "mode": "ANY"
367 }
368 }
369 });
370
371 if let Some(system) = system_instruction {
372 body["systemInstruction"] = json!({
373 "parts": [{"text": system}]
374 });
375 }
376
377 Ok(body)
378 }
379
380 fn messages_to_contents(&self, messages: &[Message]) -> Vec<Value> {
386 self.messages_to_contents_inner(messages, true)
387 }
388
389 fn messages_to_contents_text(&self, messages: &[Message]) -> Vec<Value> {
390 self.messages_to_contents_inner(messages, false)
391 }
392
393 fn messages_to_contents_inner(
394 &self,
395 messages: &[Message],
396 use_function_response: bool,
397 ) -> Vec<Value> {
398 let mut contents = Vec::new();
399
400 let mut i = 0;
401 while i < messages.len() {
402 let msg = &messages[i];
403 match msg.role {
404 Role::System => {
405 i += 1;
406 } Role::User => {
408 contents.push(json!({
409 "role": "user",
410 "parts": [{"text": msg.content}]
411 }));
412 i += 1;
413 }
414 Role::Assistant => {
415 if use_function_response && !msg.tool_calls.is_empty() {
416 let mut parts: Vec<Value> = Vec::new();
418 if !msg.content.is_empty() {
419 parts.push(json!({"text": msg.content}));
420 }
421 for tc in &msg.tool_calls {
422 parts.push(json!({
423 "functionCall": {
424 "name": tc.name,
425 "args": tc.arguments
426 }
427 }));
428 }
429 contents.push(json!({
430 "role": "model",
431 "parts": parts
432 }));
433 } else {
434 contents.push(json!({
435 "role": "model",
436 "parts": [{"text": msg.content}]
437 }));
438 }
439 i += 1;
440 }
441 Role::Tool => {
442 if use_function_response {
443 let mut parts = Vec::new();
446 while i < messages.len() && messages[i].role == Role::Tool {
447 let tool_msg = &messages[i];
448 let call_id = tool_msg.tool_call_id.as_deref().unwrap_or("unknown");
449 let func_name = match call_id.split('#').collect::<Vec<_>>().as_slice()
450 {
451 ["call", name, _counter] => *name,
452 _ => call_id,
453 };
454 parts.push(json!({
455 "functionResponse": {
456 "name": func_name,
457 "response": {
458 "content": tool_msg.content,
459 }
460 }
461 }));
462 i += 1;
463 }
464 contents.push(json!({
465 "role": "function",
466 "parts": parts
467 }));
468 } else {
469 let call_id = msg.tool_call_id.as_deref().unwrap_or("tool");
471 contents.push(json!({
472 "role": "user",
473 "parts": [{"text": format!("[{}] {}", call_id, msg.content)}]
474 }));
475 i += 1;
476 }
477 }
478 }
479 }
480
481 contents
482 }
483
484 fn extract_system(&self, messages: &[Message]) -> Option<String> {
485 let system_parts: Vec<&str> = messages
486 .iter()
487 .filter(|m| m.role == Role::System)
488 .map(|m| m.content.as_str())
489 .collect();
490
491 if system_parts.is_empty() {
492 None
493 } else {
494 Some(system_parts.join("\n\n"))
495 }
496 }
497
498 fn parse_response<T: DeserializeOwned>(
499 &self,
500 body: &Value,
501 rate_limit: Option<RateLimitInfo>,
502 ) -> Result<SgrResponse<T>, SgrError> {
503 let mut output: Option<T> = None;
504 let mut tool_calls = Vec::new();
505 let mut raw_text = String::new();
506 let mut call_counter: u32 = 0;
507
508 let usage = body.get("usageMetadata").and_then(|u| {
510 Some(Usage {
511 prompt_tokens: u.get("promptTokenCount")?.as_u64()? as u32,
512 completion_tokens: u.get("candidatesTokenCount")?.as_u64()? as u32,
513 total_tokens: u.get("totalTokenCount")?.as_u64()? as u32,
514 })
515 });
516
517 let candidates = body
519 .get("candidates")
520 .and_then(|c| c.as_array())
521 .ok_or(SgrError::EmptyResponse)?;
522
523 for candidate in candidates {
524 let parts = candidate
525 .get("content")
526 .and_then(|c| c.get("parts"))
527 .and_then(|p| p.as_array());
528
529 if let Some(parts) = parts {
530 for part in parts {
531 if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
533 raw_text.push_str(text);
534 if output.is_none() {
535 match serde_json::from_str::<T>(text) {
536 Ok(parsed) => output = Some(parsed),
537 Err(e) => {
538 tracing::warn!(error = %e, "failed to parse structured output");
539 }
540 }
541 }
542 }
543
544 if let Some(fc) = part.get("functionCall") {
546 let name = fc
547 .get("name")
548 .and_then(|n| n.as_str())
549 .unwrap_or("unknown")
550 .to_string();
551 let args = fc.get("args").cloned().unwrap_or(json!({}));
552 call_counter += 1;
553 tool_calls.push(ToolCall {
554 id: format!("call#{}#{}", name, call_counter),
555 name,
556 arguments: args,
557 });
558 }
559 }
560 }
561 }
562
563 if output.is_none() && tool_calls.is_empty() {
564 return Err(SgrError::EmptyResponse);
565 }
566
567 Ok(SgrResponse {
568 output,
569 tool_calls,
570 raw_text,
571 usage,
572 rate_limit,
573 })
574 }
575
576 fn extract_raw_text(&self, body: &Value) -> String {
577 let mut text = String::new();
578 if let Some(candidates) = body.get("candidates").and_then(|c| c.as_array()) {
579 for candidate in candidates {
580 if let Some(parts) = candidate
581 .get("content")
582 .and_then(|c| c.get("parts"))
583 .and_then(|p| p.as_array())
584 {
585 for part in parts {
586 if let Some(t) = part.get("text").and_then(|t| t.as_str()) {
587 text.push_str(t);
588 }
589 }
590 }
591 }
592 }
593 text
594 }
595
596 fn extract_tool_calls(&self, body: &Value) -> Vec<ToolCall> {
597 let mut calls = Vec::new();
598
599 if let Some(candidates) = body.get("candidates").and_then(|c| c.as_array()) {
600 for candidate in candidates {
601 if let Some(parts) = candidate
603 .get("content")
604 .and_then(|c| c.get("parts"))
605 .and_then(|p| p.as_array())
606 {
607 let mut call_counter = 0u32;
608 for part in parts {
609 if let Some(fc) = part.get("functionCall") {
610 let name = fc
611 .get("name")
612 .and_then(|n| n.as_str())
613 .unwrap_or("unknown")
614 .to_string();
615 let args = fc.get("args").cloned().unwrap_or(json!({}));
616 call_counter += 1;
617 calls.push(ToolCall {
618 id: format!("call#{}#{}", name, call_counter),
619 name,
620 arguments: args,
621 });
622 }
623 }
624 }
625
626 if calls.is_empty() {
629 if let Some(msg) = candidate.get("finishMessage").and_then(|m| m.as_str()) {
630 tracing::debug!(
631 finish_message = msg,
632 "parsing finishMessage for tool calls"
633 );
634 if let Some(json_start) = msg.find('{') {
635 let json_str = &msg[json_start..];
636 let json_str = if let Some(end) = json_str.rfind('}') {
638 &json_str[..=end]
639 } else {
640 json_str
641 };
642 if let Ok(tc_json) = serde_json::from_str::<Value>(json_str) {
643 let items: Vec<Value> = if let Some(actions) =
647 tc_json.get("actions").and_then(|a| a.as_array())
648 {
649 actions.clone()
650 } else {
651 vec![tc_json]
652 };
653 for item in items {
654 let name = item
655 .get("tool_name")
656 .and_then(|n| n.as_str())
657 .unwrap_or("unknown")
658 .to_string();
659 let mut args = item.clone();
660 if let Some(obj) = args.as_object_mut() {
661 obj.remove("tool_name");
662 }
663 calls.push(ToolCall {
664 id: name.clone(),
665 name,
666 arguments: args,
667 });
668 }
669 }
670 }
671 }
672 }
673 }
674 }
675
676 calls
677 }
678}
679
#[cfg(test)]
mod tests {
    use super::*;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    #[derive(Debug, Serialize, Deserialize, JsonSchema)]
    struct TestResponse {
        answer: String,
        confidence: f64,
    }

    /// Builds a client for request/response shaping tests (no network is used).
    fn client_for(api_key: &str, model: &str) -> GeminiClient {
        GeminiClient::from_api_key(api_key, model)
    }

    /// Wraps `parts` in a single-candidate `generateContent` reply.
    fn response_with_parts(parts: Value) -> Value {
        json!({ "candidates": [{ "content": { "parts": parts } }] })
    }

    #[test]
    fn builds_request_with_tools_no_json_mode() {
        let gemini = client_for("test-key", "gemini-2.5-flash");
        let history = vec![Message::system("You are a helper."), Message::user("Hello")];
        let tools = vec![crate::tool::tool::<TestResponse>("test_tool", "A test")];

        let request = gemini
            .build_request::<TestResponse>(&history, &tools)
            .unwrap();
        let generation = &request["generationConfig"];

        // With tools present, JSON mode must stay off.
        assert!(generation["responseSchema"].is_null());
        assert!(generation["responseMimeType"].is_null());
        assert!(request["tools"][0]["functionDeclarations"].is_array());
        assert_eq!(request["toolConfig"]["functionCallingConfig"]["mode"], "AUTO");

        // The system prompt goes to systemInstruction, not into contents.
        assert!(request["systemInstruction"]["parts"][0]["text"].is_string());
        let contents = request["contents"].as_array().unwrap();
        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0]["role"], "user");
    }

    #[test]
    fn builds_request_without_tools_has_json_mode() {
        let gemini = client_for("test-key", "gemini-2.5-flash");
        let request = gemini
            .build_request::<TestResponse>(&[Message::user("Hello")], &[])
            .unwrap();
        let generation = &request["generationConfig"];

        assert!(generation["responseSchema"].is_object());
        assert_eq!(generation["responseMimeType"], "application/json");
        assert!(request["tools"].is_null());
    }

    #[test]
    fn parses_text_response() {
        let gemini = client_for("test", "test");
        let mut reply = response_with_parts(json!([
            { "text": "{\"answer\": \"42\", \"confidence\": 0.95}" }
        ]));
        reply["usageMetadata"] = json!({
            "promptTokenCount": 10,
            "candidatesTokenCount": 20,
            "totalTokenCount": 30,
        });

        let parsed: SgrResponse<TestResponse> = gemini.parse_response(&reply, None).unwrap();
        assert!(parsed.tool_calls.is_empty());
        assert_eq!(parsed.usage.unwrap().total_tokens, 30);

        let value = parsed.output.unwrap();
        assert_eq!(value.answer, "42");
        assert_eq!(value.confidence, 0.95);
    }

    #[test]
    fn parses_function_call_response() {
        let gemini = client_for("test", "test");
        let reply = response_with_parts(json!([{
            "functionCall": { "name": "test_tool", "args": { "input": "/video.mp4" } }
        }]));

        let parsed: SgrResponse<TestResponse> = gemini.parse_response(&reply, None).unwrap();
        assert!(parsed.output.is_none());
        assert_eq!(parsed.tool_calls.len(), 1);

        let call = &parsed.tool_calls[0];
        assert_eq!(call.name, "test_tool");
        assert_eq!(call.arguments["input"], "/video.mp4");
        assert_eq!(call.id, "call#test_tool#1");
    }

    #[test]
    fn multiple_function_calls_get_unique_ids() {
        let gemini = client_for("test", "test");
        let reply = response_with_parts(json!([
            { "functionCall": { "name": "read_file", "args": { "path": "a.rs" } } },
            { "functionCall": { "name": "read_file", "args": { "path": "b.rs" } } },
            { "functionCall": { "name": "write_file", "args": { "path": "c.rs" } } },
        ]));

        let parsed: SgrResponse<TestResponse> = gemini.parse_response(&reply, None).unwrap();
        let ids: Vec<&str> = parsed.tool_calls.iter().map(|tc| tc.id.as_str()).collect();
        assert_eq!(ids, ["call#read_file#1", "call#read_file#2", "call#write_file#3"]);

        let distinct: std::collections::HashSet<&str> = ids.iter().copied().collect();
        assert_eq!(distinct.len(), 3);
    }

    #[test]
    fn func_name_extraction_from_call_id() {
        let gemini = client_for("test", "test");
        let history = vec![
            Message::user("test"),
            Message::tool("call#write_file#1", "Wrote file"),
            Message::tool("call#bash#2", "Output"),
            Message::tool("call#my_custom_tool#10", "Result"),
            Message::tool("old_format_id", "Legacy"),
        ];

        let contents = gemini.messages_to_contents(&history);
        assert_eq!(contents.len(), 2, "consecutive tools should be grouped");
        assert_eq!(contents[1]["role"], "function");

        let names: Vec<&str> = contents[1]["parts"]
            .as_array()
            .unwrap()
            .iter()
            .map(|part| part["functionResponse"]["name"].as_str().unwrap())
            .collect();
        assert_eq!(names, ["write_file", "bash", "my_custom_tool", "old_format_id"]);
    }

    #[test]
    fn tool_messages_separated_by_model_not_grouped() {
        let gemini = client_for("test", "test");
        let history = vec![
            Message::user("test"),
            Message::tool("call#read#1", "file A"),
            Message::assistant("thinking..."),
            Message::tool("call#read#2", "file B"),
        ];

        let contents = gemini.messages_to_contents(&history);
        assert_eq!(contents.len(), 4);
        for idx in [1, 3] {
            assert_eq!(contents[idx]["parts"].as_array().unwrap().len(), 1);
        }
    }
}