vtcode_core/llm/providers/gemini.rs

use crate::config::constants::{models, urls};
use crate::gemini::function_calling::{
    FunctionCall as GeminiFunctionCall, FunctionCallingConfig, FunctionResponse,
};
use crate::gemini::models::SystemInstruction;
use crate::gemini::streaming::{
    StreamingCandidate, StreamingError, StreamingProcessor, StreamingResponse,
};
use crate::gemini::{
    Candidate, Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part,
    Tool, ToolConfig,
};
use crate::llm::client::LLMClient;
use crate::llm::error_display;
use crate::llm::provider::{
    FinishReason, FunctionCall, LLMError, LLMProvider, LLMRequest, LLMResponse, LLMStream,
    LLMStreamEvent, Message, MessageRole, ToolCall, ToolChoice,
};
use crate::llm::types as llm_types;
use async_stream::try_stream;
use async_trait::async_trait;
use reqwest::Client as HttpClient;
use serde_json::{Map, Value, json};
use std::collections::HashMap;
use tokio::sync::mpsc;

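/// Gemini chat provider backed by Google's Generative Language HTTP API.
///
/// Holds the API key, a reusable HTTP client, the API base URL, and the
/// default model used when a caller does not specify one.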
pub struct GeminiProvider {
    api_key: String,
    http_client: HttpClient,
    base_url: String,
    model: String,
}

impl GeminiProvider {
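    /// Creates a provider targeting the default Gemini preview model.
    ///
    /// A minimal construction sketch; the environment variable name here is
    /// illustrative, not something this crate reads itself:
    ///
    /// ```ignore
    /// let api_key = std::env::var("GEMINI_API_KEY").expect("missing key");
    /// let provider = GeminiProvider::new(api_key);
    /// ```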
    pub fn new(api_key: String) -> Self {
        Self::with_model(api_key, models::GEMINI_2_5_FLASH_PREVIEW.to_string())
    }

    pub fn with_model(api_key: String, model: String) -> Self {
        Self {
            api_key,
            http_client: HttpClient::new(),
            base_url: urls::GEMINI_API_BASE.to_string(),
            model,
        }
    }

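    /// Builds a provider from optional configuration values, falling back to
    /// the default model and base URL for anything left unset.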
    pub fn from_config(
        api_key: Option<String>,
        model: Option<String>,
        base_url: Option<String>,
    ) -> Self {
        let api_key_value = api_key.unwrap_or_default();
        let mut provider = if let Some(model_value) = model {
            Self::with_model(api_key_value, model_value)
        } else {
            Self::new(api_key_value)
        };
        if let Some(base) = base_url {
            provider.base_url = base;
        }
        provider
    }
}

#[async_trait]
impl LLMProvider for GeminiProvider {
    fn name(&self) -> &str {
        "gemini"
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    fn supports_reasoning(&self, _model: &str) -> bool {
        false
    }

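    /// Issues a non-streaming `generateContent` call and converts the JSON
    /// body into the provider-agnostic `LLMResponse`.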
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse, LLMError> {
        let gemini_request = self.convert_to_gemini_request(&request)?;

        let url = format!(
            "{}/models/{}:generateContent?key={}",
            self.base_url, request.model, self.api_key
        );

        let response = self
            .http_client
            .post(&url)
            .json(&gemini_request)
            .send()
            .await
            .map_err(|e| {
                let formatted_error =
                    error_display::format_llm_error("Gemini", &format!("Network error: {}", e));
                LLMError::Network(formatted_error)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();

            // Handle specific HTTP status codes, mirroring the stream path
            if status.as_u16() == 401 || status.as_u16() == 403 {
                let formatted_error = error_display::format_llm_error(
                    "Gemini",
                    &format!("HTTP {}: {}", status, error_text),
                );
                return Err(LLMError::Authentication(formatted_error));
            }

            if status.as_u16() == 429
                || error_text.contains("insufficient_quota")
                || error_text.contains("quota")
                || error_text.contains("rate limit")
            {
                return Err(LLMError::RateLimit);
            }

            let formatted_error = error_display::format_llm_error(
                "Gemini",
                &format!("HTTP {}: {}", status, error_text),
            );
            return Err(LLMError::Provider(formatted_error));
        }

        let gemini_response: GenerateContentResponse = response.json().await.map_err(|e| {
            let formatted_error = error_display::format_llm_error(
                "Gemini",
                &format!("Failed to parse response: {}", e),
            );
            LLMError::Provider(formatted_error)
        })?;

        Self::convert_from_gemini_response(gemini_response)
    }

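    /// Issues a `streamGenerateContent` call and bridges the chunked body
    /// into an `LLMStream` of `Token` deltas followed by a `Completed` event.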
    async fn stream(&self, request: LLMRequest) -> Result<LLMStream, LLMError> {
        let gemini_request = self.convert_to_gemini_request(&request)?;

        let url = format!(
            "{}/models/{}:streamGenerateContent?key={}",
            self.base_url, request.model, self.api_key
        );

        let response = self
            .http_client
            .post(&url)
            .json(&gemini_request)
            .send()
            .await
            .map_err(|e| {
                let formatted_error =
                    error_display::format_llm_error("Gemini", &format!("Network error: {}", e));
                LLMError::Network(formatted_error)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();

            if status.as_u16() == 401 || status.as_u16() == 403 {
                let formatted_error = error_display::format_llm_error(
                    "Gemini",
                    &format!("HTTP {}: {}", status, error_text),
                );
                return Err(LLMError::Authentication(formatted_error));
            }

            if status.as_u16() == 429
                || error_text.contains("insufficient_quota")
                || error_text.contains("quota")
                || error_text.contains("rate limit")
            {
                return Err(LLMError::RateLimit);
            }

            let formatted_error = error_display::format_llm_error(
                "Gemini",
                &format!("HTTP {}: {}", status, error_text),
            );
            return Err(LLMError::Provider(formatted_error));
        }

        let (event_tx, event_rx) = mpsc::unbounded_channel::<Result<LLMStreamEvent, LLMError>>();
        let completion_sender = event_tx.clone();

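        // Drive the streaming processor on a background task. Each text
        // chunk is forwarded immediately as a Token event; once the stream
        // ends, the aggregated response (or a mapped error) is sent as the
        // final event.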
        tokio::spawn(async move {
            let mut processor = StreamingProcessor::new();
            let token_sender = completion_sender.clone();
            let mut aggregated_text = String::new();
            let mut on_chunk = |chunk: &str| -> Result<(), StreamingError> {
                if chunk.is_empty() {
                    return Ok(());
                }

                aggregated_text.push_str(chunk);

                token_sender
                    .send(Ok(LLMStreamEvent::Token {
                        delta: chunk.to_string(),
                    }))
                    .map_err(|_| StreamingError::StreamingError {
                        message: "Streaming consumer dropped".to_string(),
                        partial_content: Some(chunk.to_string()),
                    })?;
                Ok(())
            };

            let result = processor.process_stream(response, &mut on_chunk).await;
            match result {
                Ok(mut streaming_response) => {
                    if streaming_response.candidates.is_empty()
                        && !aggregated_text.trim().is_empty()
                    {
                        streaming_response.candidates.push(StreamingCandidate {
                            content: Content {
                                role: "model".to_string(),
                                parts: vec![Part::Text {
                                    text: aggregated_text.clone(),
                                }],
                            },
                            finish_reason: None,
                            index: Some(0),
                        });
                    }

                    match Self::convert_from_streaming_response(streaming_response) {
                        Ok(final_response) => {
                            let _ = completion_sender.send(Ok(LLMStreamEvent::Completed {
                                response: final_response,
                            }));
                        }
                        Err(err) => {
                            let _ = completion_sender.send(Err(err));
                        }
                    }
                }
                Err(error) => {
                    let mapped = Self::map_streaming_error(error);
                    let _ = completion_sender.send(Err(mapped));
                }
            }
        });

        drop(event_tx);

        let stream = {
            let mut receiver = event_rx;
            try_stream! {
                while let Some(event) = receiver.recv().await {
                    yield event?;
                }
            }
        };

        Ok(Box::pin(stream))
    }

    fn supported_models(&self) -> Vec<String> {
        vec![
            models::google::GEMINI_2_5_FLASH_PREVIEW.to_string(),
            models::google::GEMINI_2_5_PRO.to_string(),
        ]
    }

    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
        if !self.supported_models().contains(&request.model) {
            let formatted_error = error_display::format_llm_error(
                "Gemini",
                &format!("Unsupported model: {}", request.model),
            );
            return Err(LLMError::InvalidRequest(formatted_error));
        }
        Ok(())
    }
}

impl GeminiProvider {
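    /// Translates an `LLMRequest` into Gemini's wire format.
    ///
    /// System messages are lifted into `system_instruction`, assistant tool
    /// calls become `functionCall` parts, and tool results become
    /// `functionResponse` parts. Because Gemini keys function responses by
    /// function name rather than call id, a call-id-to-name map is built
    /// from the assistant messages first.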
    fn convert_to_gemini_request(
        &self,
        request: &LLMRequest,
    ) -> Result<GenerateContentRequest, LLMError> {
        let mut call_map: HashMap<String, String> = HashMap::new();
        for message in &request.messages {
            if message.role == MessageRole::Assistant
                && let Some(tool_calls) = &message.tool_calls
            {
                for tool_call in tool_calls {
                    call_map.insert(tool_call.id.clone(), tool_call.function.name.clone());
                }
            }
        }

        let mut contents: Vec<Content> = Vec::new();
        for message in &request.messages {
            if message.role == MessageRole::System {
                continue;
            }

            let mut parts: Vec<Part> = Vec::new();
            if message.role != MessageRole::Tool && !message.content.is_empty() {
                parts.push(Part::Text {
                    text: message.content.clone(),
                });
            }

            if message.role == MessageRole::Assistant
                && let Some(tool_calls) = &message.tool_calls
            {
                for tool_call in tool_calls {
                    let parsed_args = serde_json::from_str(&tool_call.function.arguments)
                        .unwrap_or_else(|_| json!({}));
                    parts.push(Part::FunctionCall {
                        function_call: GeminiFunctionCall {
                            name: tool_call.function.name.clone(),
                            args: parsed_args,
                            id: Some(tool_call.id.clone()),
                        },
                    });
                }
            }

            if message.role == MessageRole::Tool {
                if let Some(tool_call_id) = &message.tool_call_id {
                    let func_name = call_map
                        .get(tool_call_id)
                        .cloned()
                        .unwrap_or_else(|| tool_call_id.clone());
                    let response_text = serde_json::from_str::<Value>(&message.content)
                        .map(|value| {
                            serde_json::to_string_pretty(&value)
                                .unwrap_or_else(|_| message.content.clone())
                        })
                        .unwrap_or_else(|_| message.content.clone());

                    let response_payload = json!({
                        "name": func_name.clone(),
                        "content": [{
                            "text": response_text
                        }]
                    });

                    parts.push(Part::FunctionResponse {
                        function_response: FunctionResponse {
                            name: func_name,
                            response: response_payload,
                        },
                    });
                } else if !message.content.is_empty() {
                    parts.push(Part::Text {
                        text: message.content.clone(),
                    });
                }
            }

            if !parts.is_empty() {
                contents.push(Content {
                    role: message.role.as_gemini_str().to_string(),
                    parts,
                });
            }
        }

        let tools: Option<Vec<Tool>> = request.tools.as_ref().map(|definitions| {
            definitions
                .iter()
                .map(|tool| Tool {
                    function_declarations: vec![FunctionDeclaration {
                        name: tool.function.name.clone(),
                        description: tool.function.description.clone(),
                        parameters: tool.function.parameters.clone(),
                    }],
                })
                .collect()
        });

        let mut generation_config = Map::new();
        if let Some(max_tokens) = request.max_tokens {
            generation_config.insert("maxOutputTokens".to_string(), json!(max_tokens));
        }
        if let Some(temp) = request.temperature {
            generation_config.insert("temperature".to_string(), json!(temp));
        }
        let has_tools = request
            .tools
            .as_ref()
            .map(|defs| !defs.is_empty())
            .unwrap_or(false);
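        // Only emit a tool_config when tools are present or the caller
        // explicitly constrained tool choice; otherwise omit the field.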
        let tool_config = if has_tools || request.tool_choice.is_some() {
            Some(match request.tool_choice.as_ref() {
                Some(ToolChoice::None) => ToolConfig {
                    function_calling_config: FunctionCallingConfig::none(),
                },
                Some(ToolChoice::Any) => ToolConfig {
                    function_calling_config: FunctionCallingConfig::any(),
                },
                Some(ToolChoice::Specific(spec)) => {
                    let mut config = FunctionCallingConfig::any();
                    if spec.tool_type == "function" {
                        config.allowed_function_names = Some(vec![spec.function.name.clone()]);
                    }
                    ToolConfig {
                        function_calling_config: config,
                    }
                }
                _ => ToolConfig::auto(),
            })
        } else {
            None
        };

        Ok(GenerateContentRequest {
            contents,
            tools,
            tool_config,
            system_instruction: request
                .system_prompt
                .as_ref()
                .map(|text| SystemInstruction::new(text.clone())),
            generation_config: if generation_config.is_empty() {
                None
            } else {
                Some(Value::Object(generation_config))
            },
            reasoning_config: None,
        })
    }

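    /// Converts a Gemini response into an `LLMResponse`: text parts are
    /// concatenated, function calls are collected as tool calls, and the
    /// Gemini finish reason is mapped onto the shared `FinishReason` enum.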
    fn convert_from_gemini_response(
        response: GenerateContentResponse,
    ) -> Result<LLMResponse, LLMError> {
        let mut candidates = response.candidates.into_iter();
        let candidate = candidates.next().ok_or_else(|| {
            let formatted_error =
                error_display::format_llm_error("Gemini", "No candidate in response");
            LLMError::Provider(formatted_error)
        })?;

        if candidate.content.parts.is_empty() {
            return Ok(LLMResponse {
                content: Some(String::new()),
                tool_calls: None,
                usage: None,
                finish_reason: FinishReason::Stop,
                reasoning: None,
            });
        }

        let mut text_content = String::new();
        let mut tool_calls = Vec::new();

        for part in candidate.content.parts {
            match part {
                Part::Text { text } => {
                    text_content.push_str(&text);
                }
                Part::FunctionCall { function_call } => {
                    // Gemini may omit call ids, so synthesize a unique one when needed
                    let call_id = function_call.id.clone().unwrap_or_else(|| {
                        format!(
                            "call_{}_{}",
                            std::time::SystemTime::now()
                                .duration_since(std::time::UNIX_EPOCH)
                                .unwrap_or_default()
                                .as_nanos(),
                            tool_calls.len()
                        )
                    });
                    tool_calls.push(ToolCall {
                        id: call_id,
                        call_type: "function".to_string(),
                        function: FunctionCall {
                            name: function_call.name,
                            arguments: serde_json::to_string(&function_call.args)
                                .unwrap_or_else(|_| "{}".to_string()),
                        },
                    });
                }
                Part::FunctionResponse { .. } => {
                    // Ignore echoed tool responses to avoid duplicating tool output
                }
            }
        }

        let finish_reason = match candidate.finish_reason.as_deref() {
            Some("STOP") => FinishReason::Stop,
            Some("MAX_TOKENS") => FinishReason::Length,
            Some("SAFETY") => FinishReason::ContentFilter,
            Some("FUNCTION_CALL") => FinishReason::ToolCalls,
            Some(other) => FinishReason::Error(other.to_string()),
            None => FinishReason::Stop,
        };

        Ok(LLMResponse {
            content: if text_content.is_empty() {
                None
            } else {
                Some(text_content)
            },
            tool_calls: if tool_calls.is_empty() {
                None
            } else {
                Some(tool_calls)
            },
            usage: None,
            finish_reason,
            reasoning: None,
        })
    }

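    /// Reshapes an aggregated streaming response into the non-streaming
    /// response type so both code paths share one conversion routine.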
    fn convert_from_streaming_response(
        response: StreamingResponse,
    ) -> Result<LLMResponse, LLMError> {
        let converted_candidates: Vec<Candidate> = response
            .candidates
            .into_iter()
            .map(|candidate| Candidate {
                content: candidate.content,
                finish_reason: candidate.finish_reason,
            })
            .collect();

        let converted = GenerateContentResponse {
            candidates: converted_candidates,
            prompt_feedback: None,
            usage_metadata: response.usage_metadata,
        };

        Self::convert_from_gemini_response(converted)
    }

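    /// Maps low-level streaming failures onto `LLMError`: 401/403 become
    /// authentication errors, 429 becomes a rate limit, and everything else
    /// surfaces as a network or provider error.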
    fn map_streaming_error(error: StreamingError) -> LLMError {
        match error {
            StreamingError::NetworkError { message, .. } => {
                let formatted = error_display::format_llm_error(
                    "Gemini",
                    &format!("Network error: {}", message),
                );
                LLMError::Network(formatted)
            }
            StreamingError::ApiError {
                status_code,
                message,
                ..
            } => {
                if status_code == 401 || status_code == 403 {
                    let formatted = error_display::format_llm_error(
                        "Gemini",
                        &format!("HTTP {}: {}", status_code, message),
                    );
                    LLMError::Authentication(formatted)
                } else if status_code == 429 {
                    LLMError::RateLimit
                } else {
                    let formatted = error_display::format_llm_error(
                        "Gemini",
                        &format!("API error ({}): {}", status_code, message),
                    );
                    LLMError::Provider(formatted)
                }
            }
            StreamingError::ParseError { message, .. } => {
                let formatted =
                    error_display::format_llm_error("Gemini", &format!("Parse error: {}", message));
                LLMError::Provider(formatted)
            }
            StreamingError::TimeoutError {
                operation,
                duration,
            } => {
                let formatted = error_display::format_llm_error(
                    "Gemini",
                    &format!(
                        "Streaming timeout during {} after {:?}",
                        operation, duration
                    ),
                );
                LLMError::Network(formatted)
            }
            StreamingError::ContentError { message } => {
                let formatted = error_display::format_llm_error(
                    "Gemini",
                    &format!("Content error: {}", message),
                );
                LLMError::Provider(formatted)
            }
            StreamingError::StreamingError { message, .. } => {
                let formatted = error_display::format_llm_error(
                    "Gemini",
                    &format!("Streaming error: {}", message),
                );
                LLMError::Provider(formatted)
            }
        }
    }
}

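/// Compatibility shim for the legacy `LLMClient` interface. The prompt may be
/// plain text or a serialized `GenerateContentRequest`; the JSON form is
/// detected heuristically and converted into an `LLMRequest` before being
/// dispatched through the regular `LLMProvider` path.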
#[async_trait]
impl LLMClient for GeminiProvider {
    async fn generate(&mut self, prompt: &str) -> Result<llm_types::LLMResponse, LLMError> {
        // Check if the prompt is a serialized GenerateContentRequest
        let request = if prompt.starts_with('{') && prompt.contains("\"contents\"") {
            // Try to parse as JSON GenerateContentRequest
            match serde_json::from_str::<crate::gemini::GenerateContentRequest>(prompt) {
                Ok(gemini_request) => {
                    // Convert GenerateContentRequest to LLMRequest
                    let mut messages = Vec::new();
                    let mut system_prompt = None;

                    // Convert contents to messages
                    for content in &gemini_request.contents {
                        let role = match content.role.as_str() {
                            crate::config::constants::message_roles::USER => MessageRole::User,
                            "model" => MessageRole::Assistant,
                            crate::config::constants::message_roles::SYSTEM => {
                                // Extract system message
                                let text = content
                                    .parts
                                    .iter()
                                    .filter_map(|part| part.as_text())
                                    .collect::<Vec<_>>()
                                    .join("");
                                system_prompt = Some(text);
                                continue;
                            }
                            _ => MessageRole::User, // Default to user
                        };

                        let content_text = content
                            .parts
                            .iter()
                            .filter_map(|part| part.as_text())
                            .collect::<Vec<_>>()
                            .join("");

                        messages.push(Message {
                            role,
                            content: content_text,
                            tool_calls: None,
                            tool_call_id: None,
                        });
                    }


                    // Convert tools if present
                    let tools = gemini_request.tools.as_ref().map(|gemini_tools| {
                        gemini_tools
                            .iter()
                            .flat_map(|tool| &tool.function_declarations)
                            .map(|decl| crate::llm::provider::ToolDefinition {
                                tool_type: "function".to_string(),
                                function: crate::llm::provider::FunctionDefinition {
                                    name: decl.name.clone(),
                                    description: decl.description.clone(),
                                    parameters: decl.parameters.clone(),
                                },
                            })
                            .collect::<Vec<_>>()
                    });

                    let llm_request = LLMRequest {
                        messages,
                        system_prompt,
                        tools,
                        model: self.model.clone(),
                        max_tokens: gemini_request
                            .generation_config
                            .as_ref()
                            .and_then(|config| config.get("maxOutputTokens"))
                            .and_then(|v| v.as_u64())
                            .map(|v| v as u32),
                        temperature: gemini_request
                            .generation_config
                            .as_ref()
                            .and_then(|config| config.get("temperature"))
                            .and_then(|v| v.as_f64())
                            .map(|v| v as f32),
                        stream: false,
                        tool_choice: None,
                        parallel_tool_calls: None,
                        parallel_tool_config: None,
                        reasoning_effort: None,
                    };

                    // Use the standard LLMProvider generate method
                    let response = LLMProvider::generate(self, llm_request).await?;

                    // If there are tool calls, include them in the response content as JSON
                    let content = if let Some(tool_calls) = &response.tool_calls {
                        if !tool_calls.is_empty() {
                            // Create a JSON structure that the agent can parse
                            let tool_call_json = json!({
                                "tool_calls": tool_calls.iter().map(|tc| {
                                    json!({
                                        "function": {
                                            "name": tc.function.name,
                                            "arguments": tc.function.arguments
                                        }
                                    })
                                }).collect::<Vec<_>>()
                            });
                            tool_call_json.to_string()
                        } else {
                            response.content.unwrap_or_default()
                        }
                    } else {
                        response.content.unwrap_or_default()
                    };


                    return Ok(llm_types::LLMResponse {
                        content,
                        model: self.model.clone(),
                        usage: response.usage.map(|u| llm_types::Usage {
                            prompt_tokens: u.prompt_tokens as usize,
                            completion_tokens: u.completion_tokens as usize,
                            total_tokens: u.total_tokens as usize,
                        }),
                        reasoning: response.reasoning,
                    });
                }
                Err(_) => {
                    // Fallback: treat as regular prompt
                    LLMRequest {
                        messages: vec![Message {
                            role: MessageRole::User,
                            content: prompt.to_string(),
                            tool_calls: None,
                            tool_call_id: None,
                        }],
                        system_prompt: None,
                        tools: None,
                        model: self.model.clone(),
                        max_tokens: None,
                        temperature: None,
                        stream: false,
                        tool_choice: None,
                        parallel_tool_calls: None,
                        parallel_tool_config: None,
                        reasoning_effort: None,
                    }
                }
            }
        } else {
            // Fallback: treat as regular prompt
            LLMRequest {
                messages: vec![Message {
                    role: MessageRole::User,
                    content: prompt.to_string(),
                    tool_calls: None,
                    tool_call_id: None,
                }],
                system_prompt: None,
                tools: None,
                model: self.model.clone(),
                max_tokens: None,
                temperature: None,
                stream: false,
                tool_choice: None,
                parallel_tool_calls: None,
                parallel_tool_config: None,
                reasoning_effort: None,
            }
        };


        let response = LLMProvider::generate(self, request).await?;

        Ok(llm_types::LLMResponse {
            content: response.content.unwrap_or_default(),
            model: self.model.clone(),
            usage: response.usage.map(|u| llm_types::Usage {
                prompt_tokens: u.prompt_tokens as usize,
                completion_tokens: u.completion_tokens as usize,
                total_tokens: u.total_tokens as usize,
            }),
            reasoning: response.reasoning,
        })
    }

    fn backend_kind(&self) -> llm_types::BackendKind {
        llm_types::BackendKind::Gemini
    }

    fn model_id(&self) -> &str {
        &self.model
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::constants::models;
    use crate::llm::provider::{SpecificFunctionChoice, SpecificToolChoice, ToolDefinition};

    #[test]
    fn convert_to_gemini_request_maps_history_and_system_prompt() {
        let provider = GeminiProvider::new("test-key".to_string());
        let mut assistant_message = Message::assistant("Sure thing".to_string());
        assistant_message.tool_calls = Some(vec![ToolCall::function(
            "call_1".to_string(),
            "list_files".to_string(),
            json!({ "path": "." }).to_string(),
        )]);

        let tool_response =
            Message::tool_response("call_1".to_string(), json!({ "result": "ok" }).to_string());

        let tool_def = ToolDefinition::function(
            "list_files".to_string(),
            "List files".to_string(),
            json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string" }
                }
            }),
        );

        let request = LLMRequest {
            messages: vec![
                Message::user("hello".to_string()),
                assistant_message,
                tool_response,
            ],
            system_prompt: Some("System prompt".to_string()),
            tools: Some(vec![tool_def]),
            model: models::google::GEMINI_2_5_FLASH_PREVIEW.to_string(),
            max_tokens: Some(256),
            temperature: Some(0.4),
            stream: false,
            tool_choice: Some(ToolChoice::Specific(SpecificToolChoice {
                tool_type: "function".to_string(),
                function: SpecificFunctionChoice {
                    name: "list_files".to_string(),
                },
            })),
            parallel_tool_calls: None,
            parallel_tool_config: None,
            reasoning_effort: None,
        };

        let gemini_request = provider
            .convert_to_gemini_request(&request)
            .expect("conversion should succeed");

        let system_instruction = gemini_request
            .system_instruction
            .expect("system instruction should be present");
        assert!(matches!(
            system_instruction.parts.as_slice(),
            [Part::Text { text }] if text == "System prompt"
        ));

        assert_eq!(gemini_request.contents.len(), 3);
        assert_eq!(gemini_request.contents[0].role, "user");
        assert!(
            gemini_request.contents[1]
                .parts
                .iter()
                .any(|part| matches!(part, Part::FunctionCall { .. }))
        );
        let tool_part = gemini_request.contents[2]
            .parts
            .iter()
            .find_map(|part| match part {
                Part::FunctionResponse { function_response } => Some(function_response),
                _ => None,
            })
            .expect("tool response part should exist");
        assert_eq!(tool_part.name, "list_files");
    }

    #[test]
    fn convert_from_gemini_response_extracts_tool_calls() {
        let response = GenerateContentResponse {
            candidates: vec![Candidate {
                content: Content {
                    role: "model".to_string(),
                    parts: vec![
                        Part::Text {
                            text: "Here you go".to_string(),
                        },
                        Part::FunctionCall {
                            function_call: GeminiFunctionCall {
                                name: "list_files".to_string(),
                                args: json!({ "path": "." }),
                                id: Some("call_1".to_string()),
                            },
                        },
                    ],
                },
                finish_reason: Some("FUNCTION_CALL".to_string()),
            }],
            prompt_feedback: None,
            usage_metadata: None,
        };

        let llm_response = GeminiProvider::convert_from_gemini_response(response)
            .expect("conversion should succeed");

        assert_eq!(llm_response.content.as_deref(), Some("Here you go"));
        let calls = llm_response
            .tool_calls
            .expect("tool call should be present");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "list_files");
        assert!(calls[0].function.arguments.contains("path"));
        assert_eq!(llm_response.finish_reason, FinishReason::ToolCalls);
    }
}
901}