1use crate::config::constants::{models, urls};
2use crate::gemini::function_calling::{
3 FunctionCall as GeminiFunctionCall, FunctionCallingConfig, FunctionResponse,
4};
5use crate::gemini::models::SystemInstruction;
6use crate::gemini::streaming::{
7 StreamingCandidate, StreamingError, StreamingProcessor, StreamingResponse,
8};
9use crate::gemini::{
10 Candidate, Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part,
11 Tool, ToolConfig,
12};
13use crate::llm::client::LLMClient;
14use crate::llm::error_display;
15use crate::llm::provider::{
16 FinishReason, FunctionCall, LLMError, LLMProvider, LLMRequest, LLMResponse, LLMStream,
17 LLMStreamEvent, Message, MessageRole, ToolCall, ToolChoice,
18};
19use crate::llm::types as llm_types;
20use async_stream::try_stream;
21use async_trait::async_trait;
22use reqwest::Client as HttpClient;
23use serde_json::{Map, Value, json};
24use std::collections::HashMap;
25use tokio::sync::mpsc;
26
27pub struct GeminiProvider {
28 api_key: String,
29 http_client: HttpClient,
30 base_url: String,
31 model: String,
32}
33
34impl GeminiProvider {
35 pub fn new(api_key: String) -> Self {
36 Self::with_model(api_key, models::GEMINI_2_5_FLASH_PREVIEW.to_string())
37 }
38
39 pub fn with_model(api_key: String, model: String) -> Self {
40 Self {
41 api_key,
42 http_client: HttpClient::new(),
43 base_url: urls::GEMINI_API_BASE.to_string(),
44 model,
45 }
46 }
47
48 pub fn from_config(
49 api_key: Option<String>,
50 model: Option<String>,
51 base_url: Option<String>,
52 ) -> Self {
53 let api_key_value = api_key.unwrap_or_default();
54 let mut provider = if let Some(model_value) = model {
55 Self::with_model(api_key_value, model_value)
56 } else {
57 Self::new(api_key_value)
58 };
59 if let Some(base) = base_url {
60 provider.base_url = base;
61 }
62 provider
63 }
64}
65
66#[async_trait]
67impl LLMProvider for GeminiProvider {
68 fn name(&self) -> &str {
69 "gemini"
70 }
71
72 fn supports_streaming(&self) -> bool {
73 true
74 }
75
76 fn supports_reasoning(&self, _model: &str) -> bool {
77 false
78 }
79
80 async fn generate(&self, request: LLMRequest) -> Result<LLMResponse, LLMError> {
81 let gemini_request = self.convert_to_gemini_request(&request)?;
82
83 let url = format!(
84 "{}/models/{}:generateContent?key={}",
85 self.base_url, request.model, self.api_key
86 );
87
88 let response = self
89 .http_client
90 .post(&url)
91 .json(&gemini_request)
92 .send()
93 .await
94 .map_err(|e| {
95 let formatted_error =
96 error_display::format_llm_error("Gemini", &format!("Network error: {}", e));
97 LLMError::Network(formatted_error)
98 })?;
99
100 if !response.status().is_success() {
101 let status = response.status();
102 let error_text = response.text().await.unwrap_or_default();
103
104 if status.as_u16() == 429
106 || error_text.contains("insufficient_quota")
107 || error_text.contains("quota")
108 || error_text.contains("rate limit")
109 {
110 return Err(LLMError::RateLimit);
111 }
112
113 let formatted_error = error_display::format_llm_error(
114 "Gemini",
115 &format!("HTTP {}: {}", status, error_text),
116 );
117 return Err(LLMError::Provider(formatted_error));
118 }
119
120 let gemini_response: GenerateContentResponse = response.json().await.map_err(|e| {
121 let formatted_error = error_display::format_llm_error(
122 "Gemini",
123 &format!("Failed to parse response: {}", e),
124 );
125 LLMError::Provider(formatted_error)
126 })?;
127
128 Self::convert_from_gemini_response(gemini_response)
129 }
130
131 async fn stream(&self, request: LLMRequest) -> Result<LLMStream, LLMError> {
132 let gemini_request = self.convert_to_gemini_request(&request)?;
133
134 let url = format!(
135 "{}/models/{}:streamGenerateContent?key={}",
136 self.base_url, request.model, self.api_key
137 );
138
139 let response = self
140 .http_client
141 .post(&url)
142 .json(&gemini_request)
143 .send()
144 .await
145 .map_err(|e| {
146 let formatted_error =
147 error_display::format_llm_error("Gemini", &format!("Network error: {}", e));
148 LLMError::Network(formatted_error)
149 })?;
150
151 if !response.status().is_success() {
152 let status = response.status();
153 let error_text = response.text().await.unwrap_or_default();
154
155 if status.as_u16() == 401 || status.as_u16() == 403 {
156 let formatted_error = error_display::format_llm_error(
157 "Gemini",
158 &format!("HTTP {}: {}", status, error_text),
159 );
160 return Err(LLMError::Authentication(formatted_error));
161 }
162
163 if status.as_u16() == 429
164 || error_text.contains("insufficient_quota")
165 || error_text.contains("quota")
166 || error_text.contains("rate limit")
167 {
168 return Err(LLMError::RateLimit);
169 }
170
171 let formatted_error = error_display::format_llm_error(
172 "Gemini",
173 &format!("HTTP {}: {}", status, error_text),
174 );
175 return Err(LLMError::Provider(formatted_error));
176 }
177
178 let (event_tx, event_rx) = mpsc::unbounded_channel::<Result<LLMStreamEvent, LLMError>>();
179 let completion_sender = event_tx.clone();
180
181 tokio::spawn(async move {
182 let mut processor = StreamingProcessor::new();
183 let token_sender = completion_sender.clone();
184 let mut aggregated_text = String::new();
185 let mut on_chunk = |chunk: &str| -> Result<(), StreamingError> {
186 if chunk.is_empty() {
187 return Ok(());
188 }
189
190 aggregated_text.push_str(chunk);
191
192 token_sender
193 .send(Ok(LLMStreamEvent::Token {
194 delta: chunk.to_string(),
195 }))
196 .map_err(|_| StreamingError::StreamingError {
197 message: "Streaming consumer dropped".to_string(),
198 partial_content: Some(chunk.to_string()),
199 })?;
200 Ok(())
201 };
202
203 let result = processor.process_stream(response, &mut on_chunk).await;
204 match result {
205 Ok(mut streaming_response) => {
206 if streaming_response.candidates.is_empty()
207 && !aggregated_text.trim().is_empty()
208 {
209 streaming_response.candidates.push(StreamingCandidate {
210 content: Content {
211 role: "model".to_string(),
212 parts: vec![Part::Text {
213 text: aggregated_text.clone(),
214 }],
215 },
216 finish_reason: None,
217 index: Some(0),
218 });
219 }
220
221 match Self::convert_from_streaming_response(streaming_response) {
222 Ok(final_response) => {
223 let _ = completion_sender.send(Ok(LLMStreamEvent::Completed {
224 response: final_response,
225 }));
226 }
227 Err(err) => {
228 let _ = completion_sender.send(Err(err));
229 }
230 }
231 }
232 Err(error) => {
233 let mapped = Self::map_streaming_error(error);
234 let _ = completion_sender.send(Err(mapped));
235 }
236 }
237 });
238
239 drop(event_tx);
240
241 let stream = {
242 let mut receiver = event_rx;
243 try_stream! {
244 while let Some(event) = receiver.recv().await {
245 yield event?;
246 }
247 }
248 };
249
250 Ok(Box::pin(stream))
251 }
252
253 fn supported_models(&self) -> Vec<String> {
254 vec![
255 models::google::GEMINI_2_5_FLASH_PREVIEW.to_string(),
256 models::google::GEMINI_2_5_PRO.to_string(),
257 ]
258 }
259
260 fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
261 if !self.supported_models().contains(&request.model) {
262 let formatted_error = error_display::format_llm_error(
263 "Gemini",
264 &format!("Unsupported model: {}", request.model),
265 );
266 return Err(LLMError::InvalidRequest(formatted_error));
267 }
268 Ok(())
269 }
270}
271
272impl GeminiProvider {
273 fn convert_to_gemini_request(
274 &self,
275 request: &LLMRequest,
276 ) -> Result<GenerateContentRequest, LLMError> {
277 let mut call_map: HashMap<String, String> = HashMap::new();
278 for message in &request.messages {
279 if message.role == MessageRole::Assistant
280 && let Some(tool_calls) = &message.tool_calls
281 {
282 for tool_call in tool_calls {
283 call_map.insert(tool_call.id.clone(), tool_call.function.name.clone());
284 }
285 }
286 }
287
288 let mut contents: Vec<Content> = Vec::new();
289 for message in &request.messages {
290 if message.role == MessageRole::System {
291 continue;
292 }
293
294 let mut parts: Vec<Part> = Vec::new();
295 if message.role != MessageRole::Tool && !message.content.is_empty() {
296 parts.push(Part::Text {
297 text: message.content.clone(),
298 });
299 }
300
301 if message.role == MessageRole::Assistant
302 && let Some(tool_calls) = &message.tool_calls
303 {
304 for tool_call in tool_calls {
305 let parsed_args = serde_json::from_str(&tool_call.function.arguments)
306 .unwrap_or_else(|_| json!({}));
307 parts.push(Part::FunctionCall {
308 function_call: GeminiFunctionCall {
309 name: tool_call.function.name.clone(),
310 args: parsed_args,
311 id: Some(tool_call.id.clone()),
312 },
313 });
314 }
315 }
316
317 if message.role == MessageRole::Tool {
318 if let Some(tool_call_id) = &message.tool_call_id {
319 let func_name = call_map
320 .get(tool_call_id)
321 .cloned()
322 .unwrap_or_else(|| tool_call_id.clone());
323 let response_text = serde_json::from_str::<Value>(&message.content)
324 .map(|value| {
325 serde_json::to_string_pretty(&value)
326 .unwrap_or_else(|_| message.content.clone())
327 })
328 .unwrap_or_else(|_| message.content.clone());
329
330 let response_payload = json!({
331 "name": func_name.clone(),
332 "content": [{
333 "text": response_text
334 }]
335 });
336
337 parts.push(Part::FunctionResponse {
338 function_response: FunctionResponse {
339 name: func_name,
340 response: response_payload,
341 },
342 });
343 } else if !message.content.is_empty() {
344 parts.push(Part::Text {
345 text: message.content.clone(),
346 });
347 }
348 }
349
350 if !parts.is_empty() {
351 contents.push(Content {
352 role: message.role.as_gemini_str().to_string(),
353 parts,
354 });
355 }
356 }
357
358 let tools: Option<Vec<Tool>> = request.tools.as_ref().map(|definitions| {
359 definitions
360 .iter()
361 .map(|tool| Tool {
362 function_declarations: vec![FunctionDeclaration {
363 name: tool.function.name.clone(),
364 description: tool.function.description.clone(),
365 parameters: tool.function.parameters.clone(),
366 }],
367 })
368 .collect()
369 });
370
371 let mut generation_config = Map::new();
372 if let Some(max_tokens) = request.max_tokens {
373 generation_config.insert("maxOutputTokens".to_string(), json!(max_tokens));
374 }
375 if let Some(temp) = request.temperature {
376 generation_config.insert("temperature".to_string(), json!(temp));
377 }
378 let has_tools = request
379 .tools
380 .as_ref()
381 .map(|defs| !defs.is_empty())
382 .unwrap_or(false);
383 let tool_config = if has_tools || request.tool_choice.is_some() {
384 Some(match request.tool_choice.as_ref() {
385 Some(ToolChoice::None) => ToolConfig {
386 function_calling_config: FunctionCallingConfig::none(),
387 },
388 Some(ToolChoice::Any) => ToolConfig {
389 function_calling_config: FunctionCallingConfig::any(),
390 },
391 Some(ToolChoice::Specific(spec)) => {
392 let mut config = FunctionCallingConfig::any();
393 if spec.tool_type == "function" {
394 config.allowed_function_names = Some(vec![spec.function.name.clone()]);
395 }
396 ToolConfig {
397 function_calling_config: config,
398 }
399 }
400 _ => ToolConfig::auto(),
401 })
402 } else {
403 None
404 };
405
406 Ok(GenerateContentRequest {
407 contents,
408 tools,
409 tool_config,
410 system_instruction: request
411 .system_prompt
412 .as_ref()
413 .map(|text| SystemInstruction::new(text.clone())),
414 generation_config: if generation_config.is_empty() {
415 None
416 } else {
417 Some(Value::Object(generation_config))
418 },
419 reasoning_config: None,
420 })
421 }
422
423 fn convert_from_gemini_response(
424 response: GenerateContentResponse,
425 ) -> Result<LLMResponse, LLMError> {
426 let mut candidates = response.candidates.into_iter();
427 let candidate = candidates.next().ok_or_else(|| {
428 let formatted_error =
429 error_display::format_llm_error("Gemini", "No candidate in response");
430 LLMError::Provider(formatted_error)
431 })?;
432
433 if candidate.content.parts.is_empty() {
434 return Ok(LLMResponse {
435 content: Some(String::new()),
436 tool_calls: None,
437 usage: None,
438 finish_reason: FinishReason::Stop,
439 reasoning: None,
440 });
441 }
442
443 let mut text_content = String::new();
444 let mut tool_calls = Vec::new();
445
446 for part in candidate.content.parts {
447 match part {
448 Part::Text { text } => {
449 text_content.push_str(&text);
450 }
451 Part::FunctionCall { function_call } => {
452 let call_id = function_call.id.clone().unwrap_or_else(|| {
453 format!(
454 "call_{}_{}",
455 std::time::SystemTime::now()
456 .duration_since(std::time::UNIX_EPOCH)
457 .unwrap_or_default()
458 .as_nanos(),
459 tool_calls.len()
460 )
461 });
462 tool_calls.push(ToolCall {
463 id: call_id,
464 call_type: "function".to_string(),
465 function: FunctionCall {
466 name: function_call.name,
467 arguments: serde_json::to_string(&function_call.args)
468 .unwrap_or_else(|_| "{}".to_string()),
469 },
470 });
471 }
472 Part::FunctionResponse { .. } => {
473 }
475 }
476 }
477
478 let finish_reason = match candidate.finish_reason.as_deref() {
479 Some("STOP") => FinishReason::Stop,
480 Some("MAX_TOKENS") => FinishReason::Length,
481 Some("SAFETY") => FinishReason::ContentFilter,
482 Some("FUNCTION_CALL") => FinishReason::ToolCalls,
483 Some(other) => FinishReason::Error(other.to_string()),
484 None => FinishReason::Stop,
485 };
486
487 Ok(LLMResponse {
488 content: if text_content.is_empty() {
489 None
490 } else {
491 Some(text_content)
492 },
493 tool_calls: if tool_calls.is_empty() {
494 None
495 } else {
496 Some(tool_calls)
497 },
498 usage: None,
499 finish_reason,
500 reasoning: None,
501 })
502 }
503
504 fn convert_from_streaming_response(
505 response: StreamingResponse,
506 ) -> Result<LLMResponse, LLMError> {
507 let converted_candidates: Vec<Candidate> = response
508 .candidates
509 .into_iter()
510 .map(|candidate| Candidate {
511 content: candidate.content,
512 finish_reason: candidate.finish_reason,
513 })
514 .collect();
515
516 let converted = GenerateContentResponse {
517 candidates: converted_candidates,
518 prompt_feedback: None,
519 usage_metadata: response.usage_metadata,
520 };
521
522 Self::convert_from_gemini_response(converted)
523 }
524
525 fn map_streaming_error(error: StreamingError) -> LLMError {
526 match error {
527 StreamingError::NetworkError { message, .. } => {
528 let formatted = error_display::format_llm_error(
529 "Gemini",
530 &format!("Network error: {}", message),
531 );
532 LLMError::Network(formatted)
533 }
534 StreamingError::ApiError {
535 status_code,
536 message,
537 ..
538 } => {
539 if status_code == 401 || status_code == 403 {
540 let formatted = error_display::format_llm_error(
541 "Gemini",
542 &format!("HTTP {}: {}", status_code, message),
543 );
544 LLMError::Authentication(formatted)
545 } else if status_code == 429 {
546 LLMError::RateLimit
547 } else {
548 let formatted = error_display::format_llm_error(
549 "Gemini",
550 &format!("API error ({}): {}", status_code, message),
551 );
552 LLMError::Provider(formatted)
553 }
554 }
555 StreamingError::ParseError { message, .. } => {
556 let formatted =
557 error_display::format_llm_error("Gemini", &format!("Parse error: {}", message));
558 LLMError::Provider(formatted)
559 }
560 StreamingError::TimeoutError {
561 operation,
562 duration,
563 } => {
564 let formatted = error_display::format_llm_error(
565 "Gemini",
566 &format!(
567 "Streaming timeout during {} after {:?}",
568 operation, duration
569 ),
570 );
571 LLMError::Network(formatted)
572 }
573 StreamingError::ContentError { message } => {
574 let formatted = error_display::format_llm_error(
575 "Gemini",
576 &format!("Content error: {}", message),
577 );
578 LLMError::Provider(formatted)
579 }
580 StreamingError::StreamingError { message, .. } => {
581 let formatted = error_display::format_llm_error(
582 "Gemini",
583 &format!("Streaming error: {}", message),
584 );
585 LLMError::Provider(formatted)
586 }
587 }
588 }
589}
590
591#[async_trait]
592impl LLMClient for GeminiProvider {
593 async fn generate(&mut self, prompt: &str) -> Result<llm_types::LLMResponse, LLMError> {
594 let request = if prompt.starts_with('{') && prompt.contains("\"contents\"") {
596 match serde_json::from_str::<crate::gemini::GenerateContentRequest>(prompt) {
598 Ok(gemini_request) => {
599 let mut messages = Vec::new();
601 let mut system_prompt = None;
602
603 for content in &gemini_request.contents {
605 let role = match content.role.as_str() {
606 crate::config::constants::message_roles::USER => MessageRole::User,
607 "model" => MessageRole::Assistant,
608 crate::config::constants::message_roles::SYSTEM => {
609 let text = content
611 .parts
612 .iter()
613 .filter_map(|part| part.as_text())
614 .collect::<Vec<_>>()
615 .join("");
616 system_prompt = Some(text);
617 continue;
618 }
619 _ => MessageRole::User, };
621
622 let content_text = content
623 .parts
624 .iter()
625 .filter_map(|part| part.as_text())
626 .collect::<Vec<_>>()
627 .join("");
628
629 messages.push(Message {
630 role,
631 content: content_text,
632 tool_calls: None,
633 tool_call_id: None,
634 });
635 }
636
637 let tools = gemini_request.tools.as_ref().map(|gemini_tools| {
639 gemini_tools
640 .iter()
641 .flat_map(|tool| &tool.function_declarations)
642 .map(|decl| crate::llm::provider::ToolDefinition {
643 tool_type: "function".to_string(),
644 function: crate::llm::provider::FunctionDefinition {
645 name: decl.name.clone(),
646 description: decl.description.clone(),
647 parameters: decl.parameters.clone(),
648 },
649 })
650 .collect::<Vec<_>>()
651 });
652
653 let llm_request = LLMRequest {
654 messages,
655 system_prompt,
656 tools,
657 model: self.model.clone(),
658 max_tokens: gemini_request
659 .generation_config
660 .as_ref()
661 .and_then(|config| config.get("maxOutputTokens"))
662 .and_then(|v| v.as_u64())
663 .map(|v| v as u32),
664 temperature: gemini_request
665 .generation_config
666 .as_ref()
667 .and_then(|config| config.get("temperature"))
668 .and_then(|v| v.as_f64())
669 .map(|v| v as f32),
670 stream: false,
671 tool_choice: None,
672 parallel_tool_calls: None,
673 parallel_tool_config: None,
674 reasoning_effort: None,
675 };
676
677 let response = LLMProvider::generate(self, llm_request).await?;
679
680 let content = if let Some(tool_calls) = &response.tool_calls {
682 if !tool_calls.is_empty() {
683 let tool_call_json = json!({
685 "tool_calls": tool_calls.iter().map(|tc| {
686 json!({
687 "function": {
688 "name": tc.function.name,
689 "arguments": tc.function.arguments
690 }
691 })
692 }).collect::<Vec<_>>()
693 });
694 tool_call_json.to_string()
695 } else {
696 response.content.unwrap_or("".to_string())
697 }
698 } else {
699 response.content.unwrap_or("".to_string())
700 };
701
702 return Ok(llm_types::LLMResponse {
703 content,
704 model: self.model.clone(),
705 usage: response.usage.map(|u| llm_types::Usage {
706 prompt_tokens: u.prompt_tokens as usize,
707 completion_tokens: u.completion_tokens as usize,
708 total_tokens: u.total_tokens as usize,
709 }),
710 reasoning: response.reasoning,
711 });
712 }
713 Err(_) => {
714 LLMRequest {
716 messages: vec![Message {
717 role: MessageRole::User,
718 content: prompt.to_string(),
719 tool_calls: None,
720 tool_call_id: None,
721 }],
722 system_prompt: None,
723 tools: None,
724 model: self.model.clone(),
725 max_tokens: None,
726 temperature: None,
727 stream: false,
728 tool_choice: None,
729 parallel_tool_calls: None,
730 parallel_tool_config: None,
731 reasoning_effort: None,
732 }
733 }
734 }
735 } else {
736 LLMRequest {
738 messages: vec![Message {
739 role: MessageRole::User,
740 content: prompt.to_string(),
741 tool_calls: None,
742 tool_call_id: None,
743 }],
744 system_prompt: None,
745 tools: None,
746 model: self.model.clone(),
747 max_tokens: None,
748 temperature: None,
749 stream: false,
750 tool_choice: None,
751 parallel_tool_calls: None,
752 parallel_tool_config: None,
753 reasoning_effort: None,
754 }
755 };
756
757 let response = LLMProvider::generate(self, request).await?;
758
759 Ok(llm_types::LLMResponse {
760 content: response.content.unwrap_or("".to_string()),
761 model: self.model.clone(),
762 usage: response.usage.map(|u| llm_types::Usage {
763 prompt_tokens: u.prompt_tokens as usize,
764 completion_tokens: u.completion_tokens as usize,
765 total_tokens: u.total_tokens as usize,
766 }),
767 reasoning: response.reasoning,
768 })
769 }
770
771 fn backend_kind(&self) -> llm_types::BackendKind {
772 llm_types::BackendKind::Gemini
773 }
774
775 fn model_id(&self) -> &str {
776 &self.model
777 }
778}
779
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::constants::models;
    use crate::llm::provider::{SpecificFunctionChoice, SpecificToolChoice, ToolDefinition};

    /// Request exercising every conversion path: a user turn, an assistant turn
    /// carrying a `list_files` tool call, the matching tool result, a system
    /// prompt, generation settings and a forced tool choice.
    fn history_request() -> LLMRequest {
        let list_files_call = ToolCall::function(
            "call_1".to_string(),
            "list_files".to_string(),
            json!({ "path": "." }).to_string(),
        );
        let mut assistant_turn = Message::assistant("Sure thing".to_string());
        assistant_turn.tool_calls = Some(vec![list_files_call]);

        let tool_result =
            Message::tool_response("call_1".to_string(), json!({ "result": "ok" }).to_string());

        let list_files_tool = ToolDefinition::function(
            "list_files".to_string(),
            "List files".to_string(),
            json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string" }
                }
            }),
        );

        LLMRequest {
            messages: vec![Message::user("hello".to_string()), assistant_turn, tool_result],
            system_prompt: Some("System prompt".to_string()),
            tools: Some(vec![list_files_tool]),
            model: models::google::GEMINI_2_5_FLASH_PREVIEW.to_string(),
            max_tokens: Some(256),
            temperature: Some(0.4),
            stream: false,
            tool_choice: Some(ToolChoice::Specific(SpecificToolChoice {
                tool_type: "function".to_string(),
                function: SpecificFunctionChoice {
                    name: "list_files".to_string(),
                },
            })),
            parallel_tool_calls: None,
            parallel_tool_config: None,
            reasoning_effort: None,
        }
    }

    #[test]
    fn convert_to_gemini_request_maps_history_and_system_prompt() {
        let provider = GeminiProvider::new("test-key".to_string());
        let converted = provider
            .convert_to_gemini_request(&history_request())
            .expect("conversion should succeed");

        let instruction = converted
            .system_instruction
            .expect("system instruction should be present");
        assert!(matches!(
            instruction.parts.as_slice(),
            [Part::Text { text }] if text == "System prompt"
        ));

        let contents = converted.contents;
        assert_eq!(contents.len(), 3);
        assert_eq!(contents[0].role, "user");
        assert!(contents[1]
            .parts
            .iter()
            .any(|part| matches!(part, Part::FunctionCall { .. })));

        let tool_part = contents[2]
            .parts
            .iter()
            .find_map(|part| {
                if let Part::FunctionResponse { function_response } = part {
                    Some(function_response)
                } else {
                    None
                }
            })
            .expect("tool response part should exist");
        assert_eq!(tool_part.name, "list_files");
    }

    #[test]
    fn convert_from_gemini_response_extracts_tool_calls() {
        let text_part = Part::Text {
            text: "Here you go".to_string(),
        };
        let call_part = Part::FunctionCall {
            function_call: GeminiFunctionCall {
                name: "list_files".to_string(),
                args: json!({ "path": "." }),
                id: Some("call_1".to_string()),
            },
        };
        let response = GenerateContentResponse {
            candidates: vec![Candidate {
                content: Content {
                    role: "model".to_string(),
                    parts: vec![text_part, call_part],
                },
                finish_reason: Some("FUNCTION_CALL".to_string()),
            }],
            prompt_feedback: None,
            usage_metadata: None,
        };

        let converted = GeminiProvider::convert_from_gemini_response(response)
            .expect("conversion should succeed");

        assert_eq!(converted.content.as_deref(), Some("Here you go"));
        assert_eq!(converted.finish_reason, FinishReason::ToolCalls);

        let calls = converted
            .tool_calls
            .expect("tool call should be present");
        let [call] = calls.as_slice() else {
            panic!("expected exactly one tool call, got {}", calls.len());
        };
        assert_eq!(call.function.name, "list_files");
        assert!(call.function.arguments.contains("path"));
    }
}