vtcode_core/llm/providers/gemini.rs

use crate::config::constants::{models, urls};
use crate::config::core::{GeminiPromptCacheMode, GeminiPromptCacheSettings, PromptCachingConfig};
use crate::gemini::function_calling::{
    FunctionCall as GeminiFunctionCall, FunctionCallingConfig, FunctionResponse,
};
use crate::gemini::models::SystemInstruction;
use crate::gemini::streaming::{
    StreamingCandidate, StreamingError, StreamingProcessor, StreamingResponse,
};
use crate::gemini::{
    Candidate, Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part,
    Tool, ToolConfig,
};
use crate::llm::client::LLMClient;
use crate::llm::error_display;
use crate::llm::provider::{
    FinishReason, FunctionCall, LLMError, LLMProvider, LLMRequest, LLMResponse, LLMStream,
    LLMStreamEvent, Message, MessageRole, ToolCall, ToolChoice,
};
use crate::llm::types as llm_types;
use async_stream::try_stream;
use async_trait::async_trait;
use reqwest::Client as HttpClient;
use serde_json::{Map, Value, json};
use std::collections::HashMap;
use tokio::sync::mpsc;

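/// Google Gemini chat provider backed by the public `generateContent` and
/// `streamGenerateContent` REST endpoints.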
pub struct GeminiProvider {
    api_key: String,
    http_client: HttpClient,
    base_url: String,
    model: String,
    prompt_cache_enabled: bool,
    prompt_cache_settings: GeminiPromptCacheSettings,
}

impl GeminiProvider {
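    /// Creates a provider for the default Gemini model, default base URL, and no prompt caching.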
    pub fn new(api_key: String) -> Self {
        Self::with_model_internal(api_key, models::GEMINI_2_5_FLASH_PREVIEW.to_string(), None)
    }

    pub fn with_model(api_key: String, model: String) -> Self {
        Self::with_model_internal(api_key, model, None)
    }

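    /// Builds a provider from optional configuration values, falling back to the default
    /// model, the default API base URL, and disabled prompt caching when values are absent.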
    pub fn from_config(
        api_key: Option<String>,
        model: Option<String>,
        base_url: Option<String>,
        prompt_cache: Option<PromptCachingConfig>,
    ) -> Self {
        let api_key_value = api_key.unwrap_or_default();
        let mut provider = if let Some(model_value) = model {
            Self::with_model_internal(api_key_value, model_value, prompt_cache)
        } else {
            Self::with_model_internal(
                api_key_value,
                models::GEMINI_2_5_FLASH_PREVIEW.to_string(),
                prompt_cache,
            )
        };
        if let Some(base) = base_url {
            provider.base_url = base;
        }
        provider
    }

    fn with_model_internal(
        api_key: String,
        model: String,
        prompt_cache: Option<PromptCachingConfig>,
    ) -> Self {
        let (prompt_cache_enabled, prompt_cache_settings) =
            Self::extract_prompt_cache_settings(prompt_cache);

        Self {
            api_key,
            http_client: HttpClient::new(),
            base_url: urls::GEMINI_API_BASE.to_string(),
            model,
            prompt_cache_enabled,
            prompt_cache_settings,
        }
    }

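    /// Resolves the effective prompt-cache settings. Caching counts as enabled only when it
    /// is switched on globally, enabled for the Gemini provider, and the mode is not `Off`.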
    fn extract_prompt_cache_settings(
        prompt_cache: Option<PromptCachingConfig>,
    ) -> (bool, GeminiPromptCacheSettings) {
        if let Some(cfg) = prompt_cache {
            let provider_settings = cfg.providers.gemini;
            let enabled = cfg.enabled
                && provider_settings.enabled
                && provider_settings.mode != GeminiPromptCacheMode::Off;
            (enabled, provider_settings)
        } else {
            (false, GeminiPromptCacheSettings::default())
        }
    }
}

#[async_trait]
impl LLMProvider for GeminiProvider {
    fn name(&self) -> &str {
        "gemini"
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    fn supports_reasoning(&self, _model: &str) -> bool {
        false
    }

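    /// Sends a non-streaming `generateContent` request and converts the reply into an `LLMResponse`.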
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse, LLMError> {
        let gemini_request = self.convert_to_gemini_request(&request)?;

        let url = format!(
            "{}/models/{}:generateContent?key={}",
            self.base_url, request.model, self.api_key
        );

        let response = self
            .http_client
            .post(&url)
            .json(&gemini_request)
            .send()
            .await
            .map_err(|e| {
                let formatted_error =
                    error_display::format_llm_error("Gemini", &format!("Network error: {}", e));
                LLMError::Network(formatted_error)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();

            // Handle specific HTTP status codes
            if status.as_u16() == 429
                || error_text.contains("insufficient_quota")
                || error_text.contains("quota")
                || error_text.contains("rate limit")
            {
                return Err(LLMError::RateLimit);
            }

            let formatted_error = error_display::format_llm_error(
                "Gemini",
                &format!("HTTP {}: {}", status, error_text),
            );
            return Err(LLMError::Provider(formatted_error));
        }

        let gemini_response: GenerateContentResponse = response.json().await.map_err(|e| {
            let formatted_error = error_display::format_llm_error(
                "Gemini",
                &format!("Failed to parse response: {}", e),
            );
            LLMError::Provider(formatted_error)
        })?;

        Self::convert_from_gemini_response(gemini_response)
    }

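    /// Streams a `streamGenerateContent` request. A spawned task forwards each text chunk
    /// through an unbounded channel as a `Token` event, then emits a final `Completed` event
    /// built from the aggregated streaming response (or an error mapped via `map_streaming_error`).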
    async fn stream(&self, request: LLMRequest) -> Result<LLMStream, LLMError> {
        let gemini_request = self.convert_to_gemini_request(&request)?;

        let url = format!(
            "{}/models/{}:streamGenerateContent?key={}",
            self.base_url, request.model, self.api_key
        );

        let response = self
            .http_client
            .post(&url)
            .json(&gemini_request)
            .send()
            .await
            .map_err(|e| {
                let formatted_error =
                    error_display::format_llm_error("Gemini", &format!("Network error: {}", e));
                LLMError::Network(formatted_error)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();

            if status.as_u16() == 401 || status.as_u16() == 403 {
                let formatted_error = error_display::format_llm_error(
                    "Gemini",
                    &format!("HTTP {}: {}", status, error_text),
                );
                return Err(LLMError::Authentication(formatted_error));
            }

            if status.as_u16() == 429
                || error_text.contains("insufficient_quota")
                || error_text.contains("quota")
                || error_text.contains("rate limit")
            {
                return Err(LLMError::RateLimit);
            }

            let formatted_error = error_display::format_llm_error(
                "Gemini",
                &format!("HTTP {}: {}", status, error_text),
            );
            return Err(LLMError::Provider(formatted_error));
        }

        let (event_tx, event_rx) = mpsc::unbounded_channel::<Result<LLMStreamEvent, LLMError>>();
        let completion_sender = event_tx.clone();

        tokio::spawn(async move {
            let mut processor = StreamingProcessor::new();
            let token_sender = completion_sender.clone();
            let mut aggregated_text = String::new();
            let mut on_chunk = |chunk: &str| -> Result<(), StreamingError> {
                if chunk.is_empty() {
                    return Ok(());
                }

                aggregated_text.push_str(chunk);

                token_sender
                    .send(Ok(LLMStreamEvent::Token {
                        delta: chunk.to_string(),
                    }))
                    .map_err(|_| StreamingError::StreamingError {
                        message: "Streaming consumer dropped".to_string(),
                        partial_content: Some(chunk.to_string()),
                    })?;
                Ok(())
            };

            let result = processor.process_stream(response, &mut on_chunk).await;
            match result {
                Ok(mut streaming_response) => {
                    if streaming_response.candidates.is_empty()
                        && !aggregated_text.trim().is_empty()
                    {
                        streaming_response.candidates.push(StreamingCandidate {
                            content: Content {
                                role: "model".to_string(),
                                parts: vec![Part::Text {
                                    text: aggregated_text.clone(),
                                }],
                            },
                            finish_reason: None,
                            index: Some(0),
                        });
                    }

                    match Self::convert_from_streaming_response(streaming_response) {
                        Ok(final_response) => {
                            let _ = completion_sender.send(Ok(LLMStreamEvent::Completed {
                                response: final_response,
                            }));
                        }
                        Err(err) => {
                            let _ = completion_sender.send(Err(err));
                        }
                    }
                }
                Err(error) => {
                    let mapped = Self::map_streaming_error(error);
                    let _ = completion_sender.send(Err(mapped));
                }
            }
        });

        drop(event_tx);

        let stream = {
            let mut receiver = event_rx;
            try_stream! {
                while let Some(event) = receiver.recv().await {
                    yield event?;
                }
            }
        };

        Ok(Box::pin(stream))
    }

    fn supported_models(&self) -> Vec<String> {
        vec![
            models::google::GEMINI_2_5_FLASH_PREVIEW.to_string(),
            models::google::GEMINI_2_5_PRO.to_string(),
        ]
    }

    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
        if !self.supported_models().contains(&request.model) {
            let formatted_error = error_display::format_llm_error(
                "Gemini",
                &format!("Unsupported model: {}", request.model),
            );
            return Err(LLMError::InvalidRequest(formatted_error));
        }
        Ok(())
    }
}

impl GeminiProvider {
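    /// Translates a provider-agnostic `LLMRequest` into a Gemini `GenerateContentRequest`,
    /// mapping chat history, tool calls and tool responses, tool configuration, the system
    /// instruction, and generation settings.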
    fn convert_to_gemini_request(
        &self,
        request: &LLMRequest,
    ) -> Result<GenerateContentRequest, LLMError> {
        if self.prompt_cache_enabled
            && matches!(
                self.prompt_cache_settings.mode,
                GeminiPromptCacheMode::Explicit
            )
        {
            // Explicit cache handling requires separate cache-lifecycle APIs, which are
            // coordinated outside of the request payload. This placeholder branch keeps the
            // configuration usage visible even when implicit mode is active.
        }

        let mut call_map: HashMap<String, String> = HashMap::new();
        for message in &request.messages {
            if message.role == MessageRole::Assistant
                && let Some(tool_calls) = &message.tool_calls
            {
                for tool_call in tool_calls {
                    call_map.insert(tool_call.id.clone(), tool_call.function.name.clone());
                }
            }
        }

        let mut contents: Vec<Content> = Vec::new();
        for message in &request.messages {
            if message.role == MessageRole::System {
                continue;
            }

            let mut parts: Vec<Part> = Vec::new();
            if message.role != MessageRole::Tool && !message.content.is_empty() {
                parts.push(Part::Text {
                    text: message.content.clone(),
                });
            }

            if message.role == MessageRole::Assistant
                && let Some(tool_calls) = &message.tool_calls
            {
                for tool_call in tool_calls {
                    let parsed_args = serde_json::from_str(&tool_call.function.arguments)
                        .unwrap_or_else(|_| json!({}));
                    parts.push(Part::FunctionCall {
                        function_call: GeminiFunctionCall {
                            name: tool_call.function.name.clone(),
                            args: parsed_args,
                            id: Some(tool_call.id.clone()),
                        },
                    });
                }
            }

            if message.role == MessageRole::Tool {
                if let Some(tool_call_id) = &message.tool_call_id {
                    let func_name = call_map
                        .get(tool_call_id)
                        .cloned()
                        .unwrap_or_else(|| tool_call_id.clone());
                    let response_text = serde_json::from_str::<Value>(&message.content)
                        .map(|value| {
                            serde_json::to_string_pretty(&value)
                                .unwrap_or_else(|_| message.content.clone())
                        })
                        .unwrap_or_else(|_| message.content.clone());

                    let response_payload = json!({
                        "name": func_name.clone(),
                        "content": [{
                            "text": response_text
                        }]
                    });

                    parts.push(Part::FunctionResponse {
                        function_response: FunctionResponse {
                            name: func_name,
                            response: response_payload,
                        },
                    });
                } else if !message.content.is_empty() {
                    parts.push(Part::Text {
                        text: message.content.clone(),
                    });
                }
            }

            if !parts.is_empty() {
                contents.push(Content {
                    role: message.role.as_gemini_str().to_string(),
                    parts,
                });
            }
        }

        let tools: Option<Vec<Tool>> = request.tools.as_ref().map(|definitions| {
            definitions
                .iter()
                .map(|tool| Tool {
                    function_declarations: vec![FunctionDeclaration {
                        name: tool.function.name.clone(),
                        description: tool.function.description.clone(),
                        parameters: tool.function.parameters.clone(),
                    }],
                })
                .collect()
        });

        let mut generation_config = Map::new();
        if let Some(max_tokens) = request.max_tokens {
            generation_config.insert("maxOutputTokens".to_string(), json!(max_tokens));
        }
        if let Some(temp) = request.temperature {
            generation_config.insert("temperature".to_string(), json!(temp));
        }
        let has_tools = request
            .tools
            .as_ref()
            .map(|defs| !defs.is_empty())
            .unwrap_or(false);
        let tool_config = if has_tools || request.tool_choice.is_some() {
            Some(match request.tool_choice.as_ref() {
                Some(ToolChoice::None) => ToolConfig {
                    function_calling_config: FunctionCallingConfig::none(),
                },
                Some(ToolChoice::Any) => ToolConfig {
                    function_calling_config: FunctionCallingConfig::any(),
                },
                Some(ToolChoice::Specific(spec)) => {
                    let mut config = FunctionCallingConfig::any();
                    if spec.tool_type == "function" {
                        config.allowed_function_names = Some(vec![spec.function.name.clone()]);
                    }
                    ToolConfig {
                        function_calling_config: config,
                    }
                }
                _ => ToolConfig::auto(),
            })
        } else {
            None
        };

        Ok(GenerateContentRequest {
            contents,
            tools,
            tool_config,
            system_instruction: request
                .system_prompt
                .as_ref()
                .map(|text| SystemInstruction::new(text.clone())),
            generation_config: if generation_config.is_empty() {
                None
            } else {
                Some(Value::Object(generation_config))
            },
            reasoning_config: None,
        })
    }

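    /// Converts a Gemini `GenerateContentResponse` into the provider-agnostic `LLMResponse`,
    /// extracting text parts, function calls, and the finish reason from the first candidate.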
    fn convert_from_gemini_response(
        response: GenerateContentResponse,
    ) -> Result<LLMResponse, LLMError> {
        let mut candidates = response.candidates.into_iter();
        let candidate = candidates.next().ok_or_else(|| {
            let formatted_error =
                error_display::format_llm_error("Gemini", "No candidate in response");
            LLMError::Provider(formatted_error)
        })?;

        if candidate.content.parts.is_empty() {
            return Ok(LLMResponse {
                content: Some(String::new()),
                tool_calls: None,
                usage: None,
                finish_reason: FinishReason::Stop,
                reasoning: None,
            });
        }

        let mut text_content = String::new();
        let mut tool_calls = Vec::new();

        for part in candidate.content.parts {
            match part {
                Part::Text { text } => {
                    text_content.push_str(&text);
                }
                Part::FunctionCall { function_call } => {
                    let call_id = function_call.id.clone().unwrap_or_else(|| {
                        format!(
                            "call_{}_{}",
                            std::time::SystemTime::now()
                                .duration_since(std::time::UNIX_EPOCH)
                                .unwrap_or_default()
                                .as_nanos(),
                            tool_calls.len()
                        )
                    });
                    tool_calls.push(ToolCall {
                        id: call_id,
                        call_type: "function".to_string(),
                        function: FunctionCall {
                            name: function_call.name,
                            arguments: serde_json::to_string(&function_call.args)
                                .unwrap_or_else(|_| "{}".to_string()),
                        },
                    });
                }
                Part::FunctionResponse { .. } => {
                    // Ignore echoed tool responses to avoid duplicating tool output
                }
            }
        }

        let finish_reason = match candidate.finish_reason.as_deref() {
            Some("STOP") => FinishReason::Stop,
            Some("MAX_TOKENS") => FinishReason::Length,
            Some("SAFETY") => FinishReason::ContentFilter,
            Some("FUNCTION_CALL") => FinishReason::ToolCalls,
            Some(other) => FinishReason::Error(other.to_string()),
            None => FinishReason::Stop,
        };

        Ok(LLMResponse {
            content: if text_content.is_empty() {
                None
            } else {
                Some(text_content)
            },
            tool_calls: if tool_calls.is_empty() {
                None
            } else {
                Some(tool_calls)
            },
            usage: None,
            finish_reason,
            reasoning: None,
        })
    }

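    /// Folds a completed `StreamingResponse` into a regular `GenerateContentResponse` so the
    /// non-streaming conversion path can be reused.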
    fn convert_from_streaming_response(
        response: StreamingResponse,
    ) -> Result<LLMResponse, LLMError> {
        let converted_candidates: Vec<Candidate> = response
            .candidates
            .into_iter()
            .map(|candidate| Candidate {
                content: candidate.content,
                finish_reason: candidate.finish_reason,
            })
            .collect();

        let converted = GenerateContentResponse {
            candidates: converted_candidates,
            prompt_feedback: None,
            usage_metadata: response.usage_metadata,
        };

        Self::convert_from_gemini_response(converted)
    }

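    /// Maps low-level streaming errors onto the shared `LLMError` variants.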
    fn map_streaming_error(error: StreamingError) -> LLMError {
        match error {
            StreamingError::NetworkError { message, .. } => {
                let formatted = error_display::format_llm_error(
                    "Gemini",
                    &format!("Network error: {}", message),
                );
                LLMError::Network(formatted)
            }
            StreamingError::ApiError {
                status_code,
                message,
                ..
            } => {
                if status_code == 401 || status_code == 403 {
                    let formatted = error_display::format_llm_error(
                        "Gemini",
                        &format!("HTTP {}: {}", status_code, message),
                    );
                    LLMError::Authentication(formatted)
                } else if status_code == 429 {
                    LLMError::RateLimit
                } else {
                    let formatted = error_display::format_llm_error(
                        "Gemini",
                        &format!("API error ({}): {}", status_code, message),
                    );
                    LLMError::Provider(formatted)
                }
            }
            StreamingError::ParseError { message, .. } => {
                let formatted =
                    error_display::format_llm_error("Gemini", &format!("Parse error: {}", message));
                LLMError::Provider(formatted)
            }
            StreamingError::TimeoutError {
                operation,
                duration,
            } => {
                let formatted = error_display::format_llm_error(
                    "Gemini",
                    &format!(
                        "Streaming timeout during {} after {:?}",
                        operation, duration
                    ),
                );
                LLMError::Network(formatted)
            }
            StreamingError::ContentError { message } => {
                let formatted = error_display::format_llm_error(
                    "Gemini",
                    &format!("Content error: {}", message),
                );
                LLMError::Provider(formatted)
            }
            StreamingError::StreamingError { message, .. } => {
                let formatted = error_display::format_llm_error(
                    "Gemini",
                    &format!("Streaming error: {}", message),
                );
                LLMError::Provider(formatted)
            }
        }
    }
}

#[async_trait]
impl LLMClient for GeminiProvider {
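    /// Generates a completion from a raw prompt. If the prompt looks like a serialized
    /// `GenerateContentRequest`, it is parsed and converted into an `LLMRequest`; otherwise
    /// the prompt is sent as a single user message.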
    async fn generate(&mut self, prompt: &str) -> Result<llm_types::LLMResponse, LLMError> {
        // Check if the prompt is a serialized GenerateContentRequest
        let request = if prompt.starts_with('{') && prompt.contains("\"contents\"") {
            // Try to parse as JSON GenerateContentRequest
            match serde_json::from_str::<crate::gemini::GenerateContentRequest>(prompt) {
                Ok(gemini_request) => {
                    // Convert GenerateContentRequest to LLMRequest
                    let mut messages = Vec::new();
                    let mut system_prompt = None;

                    // Convert contents to messages
                    for content in &gemini_request.contents {
                        let role = match content.role.as_str() {
                            crate::config::constants::message_roles::USER => MessageRole::User,
                            "model" => MessageRole::Assistant,
                            crate::config::constants::message_roles::SYSTEM => {
                                // Extract system message
                                let text = content
                                    .parts
                                    .iter()
                                    .filter_map(|part| part.as_text())
                                    .collect::<Vec<_>>()
                                    .join("");
                                system_prompt = Some(text);
                                continue;
                            }
                            _ => MessageRole::User, // Default to user
                        };

                        let content_text = content
                            .parts
                            .iter()
                            .filter_map(|part| part.as_text())
                            .collect::<Vec<_>>()
                            .join("");

                        messages.push(Message {
                            role,
                            content: content_text,
                            tool_calls: None,
                            tool_call_id: None,
                        });
                    }

                    // Convert tools if present
                    let tools = gemini_request.tools.as_ref().map(|gemini_tools| {
                        gemini_tools
                            .iter()
                            .flat_map(|tool| &tool.function_declarations)
                            .map(|decl| crate::llm::provider::ToolDefinition {
                                tool_type: "function".to_string(),
                                function: crate::llm::provider::FunctionDefinition {
                                    name: decl.name.clone(),
                                    description: decl.description.clone(),
                                    parameters: decl.parameters.clone(),
                                },
                            })
                            .collect::<Vec<_>>()
                    });

                    let llm_request = LLMRequest {
                        messages,
                        system_prompt,
                        tools,
                        model: self.model.clone(),
                        max_tokens: gemini_request
                            .generation_config
                            .as_ref()
                            .and_then(|config| config.get("maxOutputTokens"))
                            .and_then(|v| v.as_u64())
                            .map(|v| v as u32),
                        temperature: gemini_request
                            .generation_config
                            .as_ref()
                            .and_then(|config| config.get("temperature"))
                            .and_then(|v| v.as_f64())
                            .map(|v| v as f32),
                        stream: false,
                        tool_choice: None,
                        parallel_tool_calls: None,
                        parallel_tool_config: None,
                        reasoning_effort: None,
                    };

                    // Use the standard LLMProvider generate method
                    let response = LLMProvider::generate(self, llm_request).await?;

                    // If there are tool calls, include them in the response content as JSON
                    let content = if let Some(tool_calls) = &response.tool_calls {
                        if !tool_calls.is_empty() {
                            // Create a JSON structure that the agent can parse
                            let tool_call_json = json!({
                                "tool_calls": tool_calls.iter().map(|tc| {
                                    json!({
                                        "function": {
                                            "name": tc.function.name,
                                            "arguments": tc.function.arguments
                                        }
                                    })
                                }).collect::<Vec<_>>()
                            });
                            tool_call_json.to_string()
                        } else {
                            response.content.unwrap_or_default()
                        }
                    } else {
                        response.content.unwrap_or_default()
                    };

                    return Ok(llm_types::LLMResponse {
                        content,
                        model: self.model.clone(),
                        usage: response.usage.map(|u| llm_types::Usage {
                            prompt_tokens: u.prompt_tokens as usize,
                            completion_tokens: u.completion_tokens as usize,
                            total_tokens: u.total_tokens as usize,
                            cached_prompt_tokens: u.cached_prompt_tokens.map(|v| v as usize),
                            cache_creation_tokens: u.cache_creation_tokens.map(|v| v as usize),
                            cache_read_tokens: u.cache_read_tokens.map(|v| v as usize),
                        }),
                        reasoning: response.reasoning,
                    });
                }
                Err(_) => {
                    // Fallback: treat as regular prompt
                    LLMRequest {
                        messages: vec![Message {
                            role: MessageRole::User,
                            content: prompt.to_string(),
                            tool_calls: None,
                            tool_call_id: None,
                        }],
                        system_prompt: None,
                        tools: None,
                        model: self.model.clone(),
                        max_tokens: None,
                        temperature: None,
                        stream: false,
                        tool_choice: None,
                        parallel_tool_calls: None,
                        parallel_tool_config: None,
                        reasoning_effort: None,
                    }
                }
            }
        } else {
            // Fallback: treat as regular prompt
            LLMRequest {
                messages: vec![Message {
                    role: MessageRole::User,
                    content: prompt.to_string(),
                    tool_calls: None,
                    tool_call_id: None,
                }],
                system_prompt: None,
                tools: None,
                model: self.model.clone(),
                max_tokens: None,
                temperature: None,
                stream: false,
                tool_choice: None,
                parallel_tool_calls: None,
                parallel_tool_config: None,
                reasoning_effort: None,
            }
        };

        let response = LLMProvider::generate(self, request).await?;

        Ok(llm_types::LLMResponse {
            content: response.content.unwrap_or_default(),
            model: self.model.clone(),
            usage: response.usage.map(|u| llm_types::Usage {
                prompt_tokens: u.prompt_tokens as usize,
                completion_tokens: u.completion_tokens as usize,
                total_tokens: u.total_tokens as usize,
                cached_prompt_tokens: u.cached_prompt_tokens.map(|v| v as usize),
                cache_creation_tokens: u.cache_creation_tokens.map(|v| v as usize),
                cache_read_tokens: u.cache_read_tokens.map(|v| v as usize),
            }),
            reasoning: response.reasoning,
        })
    }

    fn backend_kind(&self) -> llm_types::BackendKind {
        llm_types::BackendKind::Gemini
    }

    fn model_id(&self) -> &str {
        &self.model
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::constants::models;
    use crate::llm::provider::{SpecificFunctionChoice, SpecificToolChoice, ToolDefinition};

    #[test]
    fn convert_to_gemini_request_maps_history_and_system_prompt() {
        let provider = GeminiProvider::new("test-key".to_string());
        let mut assistant_message = Message::assistant("Sure thing".to_string());
        assistant_message.tool_calls = Some(vec![ToolCall::function(
            "call_1".to_string(),
            "list_files".to_string(),
            json!({ "path": "." }).to_string(),
        )]);

        let tool_response =
            Message::tool_response("call_1".to_string(), json!({ "result": "ok" }).to_string());

        let tool_def = ToolDefinition::function(
            "list_files".to_string(),
            "List files".to_string(),
            json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string" }
                }
            }),
        );

        let request = LLMRequest {
            messages: vec![
                Message::user("hello".to_string()),
                assistant_message,
                tool_response,
            ],
            system_prompt: Some("System prompt".to_string()),
            tools: Some(vec![tool_def]),
            model: models::google::GEMINI_2_5_FLASH_PREVIEW.to_string(),
            max_tokens: Some(256),
            temperature: Some(0.4),
            stream: false,
            tool_choice: Some(ToolChoice::Specific(SpecificToolChoice {
                tool_type: "function".to_string(),
                function: SpecificFunctionChoice {
                    name: "list_files".to_string(),
                },
            })),
            parallel_tool_calls: None,
            parallel_tool_config: None,
            reasoning_effort: None,
        };

        let gemini_request = provider
            .convert_to_gemini_request(&request)
            .expect("conversion should succeed");

        let system_instruction = gemini_request
            .system_instruction
            .expect("system instruction should be present");
        assert!(matches!(
            system_instruction.parts.as_slice(),
            [Part::Text { text }] if text == "System prompt"
        ));

        assert_eq!(gemini_request.contents.len(), 3);
        assert_eq!(gemini_request.contents[0].role, "user");
        assert!(
            gemini_request.contents[1]
                .parts
                .iter()
                .any(|part| matches!(part, Part::FunctionCall { .. }))
        );
        let tool_part = gemini_request.contents[2]
            .parts
            .iter()
            .find_map(|part| match part {
                Part::FunctionResponse { function_response } => Some(function_response),
                _ => None,
            })
            .expect("tool response part should exist");
        assert_eq!(tool_part.name, "list_files");
    }

    #[test]
    fn convert_from_gemini_response_extracts_tool_calls() {
        let response = GenerateContentResponse {
            candidates: vec![crate::gemini::Candidate {
                content: Content {
                    role: "model".to_string(),
                    parts: vec![
                        Part::Text {
                            text: "Here you go".to_string(),
                        },
                        Part::FunctionCall {
                            function_call: GeminiFunctionCall {
                                name: "list_files".to_string(),
                                args: json!({ "path": "." }),
                                id: Some("call_1".to_string()),
                            },
                        },
                    ],
                },
                finish_reason: Some("FUNCTION_CALL".to_string()),
            }],
            prompt_feedback: None,
            usage_metadata: None,
        };

        let llm_response = GeminiProvider::convert_from_gemini_response(response)
            .expect("conversion should succeed");

        assert_eq!(llm_response.content.as_deref(), Some("Here you go"));
        let calls = llm_response
            .tool_calls
            .expect("tool call should be present");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "list_files");
        assert!(calls[0].function.arguments.contains("path"));
        assert_eq!(llm_response.finish_reason, FinishReason::ToolCalls);
    }
}