vtcode_core/llm/providers/gemini.rs

use crate::config::constants::{models, urls};
use crate::config::core::{GeminiPromptCacheMode, GeminiPromptCacheSettings, PromptCachingConfig};
use crate::gemini::function_calling::{
    FunctionCall as GeminiFunctionCall, FunctionCallingConfig, FunctionResponse,
};
use crate::gemini::models::SystemInstruction;
use crate::gemini::streaming::{
    StreamingCandidate, StreamingError, StreamingProcessor, StreamingResponse,
};
use crate::gemini::{
    Candidate, Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part,
    Tool, ToolConfig,
};
use crate::llm::client::LLMClient;
use crate::llm::error_display;
use crate::llm::provider::{
    FinishReason, FunctionCall, LLMError, LLMProvider, LLMRequest, LLMResponse, LLMStream,
    LLMStreamEvent, Message, MessageRole, ToolCall, ToolChoice,
};
use crate::llm::types as llm_types;
use async_stream::try_stream;
use async_trait::async_trait;
use reqwest::Client as HttpClient;
use serde_json::{Map, Value, json};
use std::collections::HashMap;
use tokio::sync::mpsc;

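/// LLM provider for Google's Gemini models, speaking the Generative Language
/// REST API (`generateContent` / `streamGenerateContent`).
///
/// A minimal construction sketch; the `vtcode_core` import path is assumed
/// from this file's location and may differ in your build:
///
/// ```ignore
/// use vtcode_core::llm::providers::gemini::GeminiProvider;
///
/// // Defaults to models::GEMINI_2_5_FLASH_PREVIEW and urls::GEMINI_API_BASE.
/// let provider = GeminiProvider::new(std::env::var("GEMINI_API_KEY").expect("key"));
/// ```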
pub struct GeminiProvider {
    api_key: String,
    http_client: HttpClient,
    base_url: String,
    model: String,
    prompt_cache_enabled: bool,
    prompt_cache_settings: GeminiPromptCacheSettings,
}

impl GeminiProvider {
    pub fn new(api_key: String) -> Self {
        Self::with_model_internal(api_key, models::GEMINI_2_5_FLASH_PREVIEW.to_string(), None)
    }

    pub fn with_model(api_key: String, model: String) -> Self {
        Self::with_model_internal(api_key, model, None)
    }

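    /// Builds a provider from optional configuration values: an explicit
    /// `model` overrides the default preview model, and `base_url` (when set)
    /// replaces `urls::GEMINI_API_BASE`. A hedged sketch of a typical call;
    /// the `prompt_cache_cfg` value is illustrative:
    ///
    /// ```ignore
    /// let provider = GeminiProvider::from_config(
    ///     std::env::var("GEMINI_API_KEY").ok(), // a missing key becomes ""
    ///     None,             // keep the default model
    ///     None,             // keep urls::GEMINI_API_BASE
    ///     prompt_cache_cfg, // Option<PromptCachingConfig> from workspace config
    /// );
    /// ```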
    pub fn from_config(
        api_key: Option<String>,
        model: Option<String>,
        base_url: Option<String>,
        prompt_cache: Option<PromptCachingConfig>,
    ) -> Self {
        let api_key_value = api_key.unwrap_or_default();
        let mut provider = if let Some(model_value) = model {
            Self::with_model_internal(api_key_value, model_value, prompt_cache)
        } else {
            Self::with_model_internal(
                api_key_value,
                models::GEMINI_2_5_FLASH_PREVIEW.to_string(),
                prompt_cache,
            )
        };
        if let Some(base) = base_url {
            provider.base_url = base;
        }
        provider
    }

    fn with_model_internal(
        api_key: String,
        model: String,
        prompt_cache: Option<PromptCachingConfig>,
    ) -> Self {
        let (prompt_cache_enabled, prompt_cache_settings) =
            Self::extract_prompt_cache_settings(prompt_cache);

        Self {
            api_key,
            http_client: HttpClient::new(),
            base_url: urls::GEMINI_API_BASE.to_string(),
            model,
            prompt_cache_enabled,
            prompt_cache_settings,
        }
    }

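    /// Prompt caching is treated as enabled only when the global flag, the
    /// Gemini provider flag, and a mode other than `Off` all agree; otherwise
    /// the defaults are returned with caching disabled.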
    fn extract_prompt_cache_settings(
        prompt_cache: Option<PromptCachingConfig>,
    ) -> (bool, GeminiPromptCacheSettings) {
        if let Some(cfg) = prompt_cache {
            let provider_settings = cfg.providers.gemini;
            let enabled = cfg.enabled
                && provider_settings.enabled
                && provider_settings.mode != GeminiPromptCacheMode::Off;
            (enabled, provider_settings)
        } else {
            (false, GeminiPromptCacheSettings::default())
        }
    }
}

#[async_trait]
impl LLMProvider for GeminiProvider {
    fn name(&self) -> &str {
        "gemini"
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    fn supports_reasoning(&self, _model: &str) -> bool {
        false
    }

    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse, LLMError> {
        let gemini_request = self.convert_to_gemini_request(&request)?;

        let url = format!(
            "{}/models/{}:generateContent?key={}",
            self.base_url, request.model, self.api_key
        );

        let response = self
            .http_client
            .post(&url)
            .json(&gemini_request)
            .send()
            .await
            .map_err(|e| {
                let formatted_error =
                    error_display::format_llm_error("Gemini", &format!("Network error: {}", e));
                LLMError::Network(formatted_error)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();

            // Handle specific HTTP status codes, mirroring the streaming path
            if status.as_u16() == 401 || status.as_u16() == 403 {
                let formatted_error = error_display::format_llm_error(
                    "Gemini",
                    &format!("HTTP {}: {}", status, error_text),
                );
                return Err(LLMError::Authentication(formatted_error));
            }

            if status.as_u16() == 429
                || error_text.contains("insufficient_quota")
                || error_text.contains("quota")
                || error_text.contains("rate limit")
            {
                return Err(LLMError::RateLimit);
            }

            let formatted_error = error_display::format_llm_error(
                "Gemini",
                &format!("HTTP {}: {}", status, error_text),
            );
            return Err(LLMError::Provider(formatted_error));
        }

        let gemini_response: GenerateContentResponse = response.json().await.map_err(|e| {
            let formatted_error = error_display::format_llm_error(
                "Gemini",
                &format!("Failed to parse response: {}", e),
            );
            LLMError::Provider(formatted_error)
        })?;

        Self::convert_from_gemini_response(gemini_response)
    }

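    /// Streams token deltas as they arrive and finishes with a `Completed`
    /// event carrying the aggregated `LLMResponse`.
    ///
    /// A consumption sketch, assuming `LLMStream` implements
    /// `futures::Stream` and the caller brings `StreamExt` into scope
    /// (illustrative only):
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut events = provider.stream(request).await?;
    /// while let Some(event) = events.next().await {
    ///     match event? {
    ///         LLMStreamEvent::Token { delta } => print!("{delta}"),
    ///         LLMStreamEvent::Completed { response } => {
    ///             // Aggregated text and any tool calls live on `response`.
    ///         }
    ///         _ => {}
    ///     }
    /// }
    /// ```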
    async fn stream(&self, request: LLMRequest) -> Result<LLMStream, LLMError> {
        let gemini_request = self.convert_to_gemini_request(&request)?;

        let url = format!(
            "{}/models/{}:streamGenerateContent?key={}",
            self.base_url, request.model, self.api_key
        );

        let response = self
            .http_client
            .post(&url)
            .json(&gemini_request)
            .send()
            .await
            .map_err(|e| {
                let formatted_error =
                    error_display::format_llm_error("Gemini", &format!("Network error: {}", e));
                LLMError::Network(formatted_error)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();

            if status.as_u16() == 401 || status.as_u16() == 403 {
                let formatted_error = error_display::format_llm_error(
                    "Gemini",
                    &format!("HTTP {}: {}", status, error_text),
                );
                return Err(LLMError::Authentication(formatted_error));
            }

            if status.as_u16() == 429
                || error_text.contains("insufficient_quota")
                || error_text.contains("quota")
                || error_text.contains("rate limit")
            {
                return Err(LLMError::RateLimit);
            }

            let formatted_error = error_display::format_llm_error(
                "Gemini",
                &format!("HTTP {}: {}", status, error_text),
            );
            return Err(LLMError::Provider(formatted_error));
        }

        let (event_tx, event_rx) = mpsc::unbounded_channel::<Result<LLMStreamEvent, LLMError>>();
        let completion_sender = event_tx.clone();

        tokio::spawn(async move {
            let mut processor = StreamingProcessor::new();
            let token_sender = completion_sender.clone();
            let mut aggregated_text = String::new();
            let mut on_chunk = |chunk: &str| -> Result<(), StreamingError> {
                if chunk.is_empty() {
                    return Ok(());
                }

                aggregated_text.push_str(chunk);

                token_sender
                    .send(Ok(LLMStreamEvent::Token {
                        delta: chunk.to_string(),
                    }))
                    .map_err(|_| StreamingError::StreamingError {
                        message: "Streaming consumer dropped".to_string(),
                        partial_content: Some(chunk.to_string()),
                    })?;
                Ok(())
            };

            let result = processor.process_stream(response, &mut on_chunk).await;
            match result {
                Ok(mut streaming_response) => {
                    if streaming_response.candidates.is_empty()
                        && !aggregated_text.trim().is_empty()
                    {
                        streaming_response.candidates.push(StreamingCandidate {
                            content: Content {
                                role: "model".to_string(),
                                parts: vec![Part::Text {
                                    text: aggregated_text.clone(),
                                }],
                            },
                            finish_reason: None,
                            index: Some(0),
                        });
                    }

                    match Self::convert_from_streaming_response(streaming_response) {
                        Ok(final_response) => {
                            let _ = completion_sender.send(Ok(LLMStreamEvent::Completed {
                                response: final_response,
                            }));
                        }
                        Err(err) => {
                            let _ = completion_sender.send(Err(err));
                        }
                    }
                }
                Err(error) => {
                    let mapped = Self::map_streaming_error(error);
                    let _ = completion_sender.send(Err(mapped));
                }
            }
        });

        drop(event_tx);

        let stream = {
            let mut receiver = event_rx;
            try_stream! {
                while let Some(event) = receiver.recv().await {
                    yield event?;
                }
            }
        };

        Ok(Box::pin(stream))
    }

    fn supported_models(&self) -> Vec<String> {
        vec![
            models::google::GEMINI_2_5_FLASH_PREVIEW.to_string(),
            models::google::GEMINI_2_5_PRO.to_string(),
        ]
    }

    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
        if !self.supported_models().contains(&request.model) {
            let formatted_error = error_display::format_llm_error(
                "Gemini",
                &format!("Unsupported model: {}", request.model),
            );
            return Err(LLMError::InvalidRequest(formatted_error));
        }
        Ok(())
    }
}

impl GeminiProvider {
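    /// Maps a provider-agnostic `LLMRequest` onto Gemini's request shape:
    /// the system prompt becomes `system_instruction`, assistant tool calls
    /// become `functionCall` parts, and tool results become `functionResponse`
    /// parts named after the assistant call that produced them.
    ///
    /// Roughly, one user turn plus a tool call serializes to the following
    /// (a sketch; exact field casing follows the serde attributes on the
    /// `gemini` types):
    ///
    /// ```text
    /// {
    ///   "contents": [
    ///     { "role": "user",  "parts": [{ "text": "hello" }] },
    ///     { "role": "model", "parts": [{ "functionCall": { "name": "list_files", "args": { "path": "." } } }] },
    ///     ...
    ///   ],
    ///   "systemInstruction": { "parts": [{ "text": "System prompt" }] }
    /// }
    /// ```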
    fn convert_to_gemini_request(
        &self,
        request: &LLMRequest,
    ) -> Result<GenerateContentRequest, LLMError> {
        if self.prompt_cache_enabled
            && matches!(
                self.prompt_cache_settings.mode,
                GeminiPromptCacheMode::Explicit
            )
        {
            // Explicit caching requires separate cache-lifecycle APIs that are
            // coordinated outside this request payload, so there is nothing to
            // attach here yet. This branch keeps the configuration path visible
            // until explicit-mode support is wired in.
        }

        let mut call_map: HashMap<String, String> = HashMap::new();
        for message in &request.messages {
            if message.role == MessageRole::Assistant
                && let Some(tool_calls) = &message.tool_calls
            {
                for tool_call in tool_calls {
                    call_map.insert(tool_call.id.clone(), tool_call.function.name.clone());
                }
            }
        }

        let mut contents: Vec<Content> = Vec::new();
        for message in &request.messages {
            if message.role == MessageRole::System {
                continue;
            }

            let mut parts: Vec<Part> = Vec::new();
            if message.role != MessageRole::Tool && !message.content.is_empty() {
                parts.push(Part::Text {
                    text: message.content.clone(),
                });
            }

            if message.role == MessageRole::Assistant
                && let Some(tool_calls) = &message.tool_calls
            {
                for tool_call in tool_calls {
                    let parsed_args = serde_json::from_str(&tool_call.function.arguments)
                        .unwrap_or_else(|_| json!({}));
                    parts.push(Part::FunctionCall {
                        function_call: GeminiFunctionCall {
                            name: tool_call.function.name.clone(),
                            args: parsed_args,
                            id: Some(tool_call.id.clone()),
                        },
                    });
                }
            }

            if message.role == MessageRole::Tool {
                if let Some(tool_call_id) = &message.tool_call_id {
                    let func_name = call_map
                        .get(tool_call_id)
                        .cloned()
                        .unwrap_or_else(|| tool_call_id.clone());
                    let response_text = serde_json::from_str::<Value>(&message.content)
                        .map(|value| {
                            serde_json::to_string_pretty(&value)
                                .unwrap_or_else(|_| message.content.clone())
                        })
                        .unwrap_or_else(|_| message.content.clone());

                    let response_payload = json!({
                        "name": func_name.clone(),
                        "content": [{
                            "text": response_text
                        }]
                    });

                    parts.push(Part::FunctionResponse {
                        function_response: FunctionResponse {
                            name: func_name,
                            response: response_payload,
                        },
                    });
                } else if !message.content.is_empty() {
                    parts.push(Part::Text {
                        text: message.content.clone(),
                    });
                }
            }

            if !parts.is_empty() {
                contents.push(Content {
                    role: message.role.as_gemini_str().to_string(),
                    parts,
                });
            }
        }

        let tools: Option<Vec<Tool>> = request.tools.as_ref().map(|definitions| {
            definitions
                .iter()
                .map(|tool| Tool {
                    function_declarations: vec![FunctionDeclaration {
                        name: tool.function.name.clone(),
                        description: tool.function.description.clone(),
                        parameters: sanitize_function_parameters(tool.function.parameters.clone()),
                    }],
                })
                .collect()
        });

        let mut generation_config = Map::new();
        if let Some(max_tokens) = request.max_tokens {
            generation_config.insert("maxOutputTokens".to_string(), json!(max_tokens));
        }
        if let Some(temp) = request.temperature {
            generation_config.insert("temperature".to_string(), json!(temp));
        }
        let has_tools = request
            .tools
            .as_ref()
            .map(|defs| !defs.is_empty())
            .unwrap_or(false);
        let tool_config = if has_tools || request.tool_choice.is_some() {
            Some(match request.tool_choice.as_ref() {
                Some(ToolChoice::None) => ToolConfig {
                    function_calling_config: FunctionCallingConfig::none(),
                },
                Some(ToolChoice::Any) => ToolConfig {
                    function_calling_config: FunctionCallingConfig::any(),
                },
                Some(ToolChoice::Specific(spec)) => {
                    let mut config = FunctionCallingConfig::any();
                    if spec.tool_type == "function" {
                        config.allowed_function_names = Some(vec![spec.function.name.clone()]);
                    }
                    ToolConfig {
                        function_calling_config: config,
                    }
                }
                _ => ToolConfig::auto(),
            })
        } else {
            None
        };

        Ok(GenerateContentRequest {
            contents,
            tools,
            tool_config,
            system_instruction: request
                .system_prompt
                .as_ref()
                .map(|text| SystemInstruction::new(text.clone())),
            generation_config: if generation_config.is_empty() {
                None
            } else {
                Some(Value::Object(generation_config))
            },
            reasoning_config: None,
        })
    }

    fn convert_from_gemini_response(
        response: GenerateContentResponse,
    ) -> Result<LLMResponse, LLMError> {
        let mut candidates = response.candidates.into_iter();
        let candidate = candidates.next().ok_or_else(|| {
            let formatted_error =
                error_display::format_llm_error("Gemini", "No candidate in response");
            LLMError::Provider(formatted_error)
        })?;

        if candidate.content.parts.is_empty() {
            return Ok(LLMResponse {
                content: Some(String::new()),
                tool_calls: None,
                usage: None,
                finish_reason: FinishReason::Stop,
                reasoning: None,
            });
        }

        let mut text_content = String::new();
        let mut tool_calls = Vec::new();

        for part in candidate.content.parts {
            match part {
                Part::Text { text } => {
                    text_content.push_str(&text);
                }
                Part::FunctionCall { function_call } => {
                    let call_id = function_call.id.clone().unwrap_or_else(|| {
                        format!(
                            "call_{}_{}",
                            std::time::SystemTime::now()
                                .duration_since(std::time::UNIX_EPOCH)
                                .unwrap_or_default()
                                .as_nanos(),
                            tool_calls.len()
                        )
                    });
                    tool_calls.push(ToolCall {
                        id: call_id,
                        call_type: "function".to_string(),
                        function: FunctionCall {
                            name: function_call.name,
                            arguments: serde_json::to_string(&function_call.args)
                                .unwrap_or_else(|_| "{}".to_string()),
                        },
                    });
                }
                Part::FunctionResponse { .. } => {
                    // Ignore echoed tool responses to avoid duplicating tool output
                }
            }
        }

        let finish_reason = match candidate.finish_reason.as_deref() {
            Some("STOP") => FinishReason::Stop,
            Some("MAX_TOKENS") => FinishReason::Length,
            Some("SAFETY") => FinishReason::ContentFilter,
            Some("FUNCTION_CALL") => FinishReason::ToolCalls,
            Some(other) => FinishReason::Error(other.to_string()),
            None => FinishReason::Stop,
        };

        Ok(LLMResponse {
            content: if text_content.is_empty() {
                None
            } else {
                Some(text_content)
            },
            tool_calls: if tool_calls.is_empty() {
                None
            } else {
                Some(tool_calls)
            },
            // Token usage from Gemini's usage_metadata is not surfaced here yet.
            usage: None,
            finish_reason,
            reasoning: None,
        })
    }

    fn convert_from_streaming_response(
        response: StreamingResponse,
    ) -> Result<LLMResponse, LLMError> {
        let converted_candidates: Vec<Candidate> = response
            .candidates
            .into_iter()
            .map(|candidate| Candidate {
                content: candidate.content,
                finish_reason: candidate.finish_reason,
            })
            .collect();

        let converted = GenerateContentResponse {
            candidates: converted_candidates,
            prompt_feedback: None,
            usage_metadata: response.usage_metadata,
        };

        Self::convert_from_gemini_response(converted)
    }

    fn map_streaming_error(error: StreamingError) -> LLMError {
        match error {
            StreamingError::NetworkError { message, .. } => {
                let formatted = error_display::format_llm_error(
                    "Gemini",
                    &format!("Network error: {}", message),
                );
                LLMError::Network(formatted)
            }
            StreamingError::ApiError {
                status_code,
                message,
                ..
            } => {
                if status_code == 401 || status_code == 403 {
                    let formatted = error_display::format_llm_error(
                        "Gemini",
                        &format!("HTTP {}: {}", status_code, message),
                    );
                    LLMError::Authentication(formatted)
                } else if status_code == 429 {
                    LLMError::RateLimit
                } else {
                    let formatted = error_display::format_llm_error(
                        "Gemini",
                        &format!("API error ({}): {}", status_code, message),
                    );
                    LLMError::Provider(formatted)
                }
            }
            StreamingError::ParseError { message, .. } => {
                let formatted =
                    error_display::format_llm_error("Gemini", &format!("Parse error: {}", message));
                LLMError::Provider(formatted)
            }
            StreamingError::TimeoutError {
                operation,
                duration,
            } => {
                let formatted = error_display::format_llm_error(
                    "Gemini",
                    &format!(
                        "Streaming timeout during {} after {:?}",
                        operation, duration
                    ),
                );
                LLMError::Network(formatted)
            }
            StreamingError::ContentError { message } => {
                let formatted = error_display::format_llm_error(
                    "Gemini",
                    &format!("Content error: {}", message),
                );
                LLMError::Provider(formatted)
            }
            StreamingError::StreamingError { message, .. } => {
                let formatted = error_display::format_llm_error(
                    "Gemini",
                    &format!("Streaming error: {}", message),
                );
                LLMError::Provider(formatted)
            }
        }
    }
}

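/// Recursively strips `additionalProperties` from a JSON Schema, which the
/// Gemini function-declaration schema dialect does not accept, while leaving
/// every other key, nested object, and array untouched. For example:
///
/// ```ignore
/// let cleaned = sanitize_function_parameters(serde_json::json!({
///     "type": "object",
///     "properties": { "path": { "type": "string" } },
///     "additionalProperties": false
/// }));
/// assert!(cleaned.get("additionalProperties").is_none());
/// ```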
fn sanitize_function_parameters(parameters: Value) -> Value {
    match parameters {
        Value::Object(map) => {
            let mut sanitized = Map::new();
            for (key, value) in map {
                if key == "additionalProperties" {
                    continue;
                }
                sanitized.insert(key, sanitize_function_parameters(value));
            }
            Value::Object(sanitized)
        }
        Value::Array(values) => Value::Array(
            values
                .into_iter()
                .map(sanitize_function_parameters)
                .collect(),
        ),
        other => other,
    }
}

#[async_trait]
impl LLMClient for GeminiProvider {
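    /// Accepts either a plain text prompt or a JSON-serialized
    /// `GenerateContentRequest`, detected heuristically by a leading `{` plus
    /// a `"contents"` key. Two illustrative calls (a sketch, not a prescribed
    /// API):
    ///
    /// ```ignore
    /// // Plain prompt: wrapped in a single user message.
    /// let reply = client.generate("Summarize this repository").await?;
    ///
    /// // Serialized request: contents, tools, and generationConfig are
    /// // translated into an LLMRequest before dispatch.
    /// let reply = client.generate(&serde_json::to_string(&gemini_request)?).await?;
    /// ```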
    async fn generate(&mut self, prompt: &str) -> Result<llm_types::LLMResponse, LLMError> {
        // Fallback used whenever the prompt is treated as plain text rather
        // than a serialized GenerateContentRequest.
        let fallback_request = || LLMRequest {
            messages: vec![Message {
                role: MessageRole::User,
                content: prompt.to_string(),
                tool_calls: None,
                tool_call_id: None,
            }],
            system_prompt: None,
            tools: None,
            model: self.model.clone(),
            max_tokens: None,
            temperature: None,
            stream: false,
            tool_choice: None,
            parallel_tool_calls: None,
            parallel_tool_config: None,
            reasoning_effort: None,
        };

        // Check if the prompt is a serialized GenerateContentRequest
        let request = if prompt.starts_with('{') && prompt.contains("\"contents\"") {
            match serde_json::from_str::<crate::gemini::GenerateContentRequest>(prompt) {
                Ok(gemini_request) => {
                    // Convert the GenerateContentRequest into an LLMRequest
                    let mut messages = Vec::new();
                    let mut system_prompt = None;

                    // Convert contents to messages
                    for content in &gemini_request.contents {
                        let role = match content.role.as_str() {
                            crate::config::constants::message_roles::USER => MessageRole::User,
                            "model" => MessageRole::Assistant,
                            crate::config::constants::message_roles::SYSTEM => {
                                // Extract the system message
                                let text = content
                                    .parts
                                    .iter()
                                    .filter_map(|part| part.as_text())
                                    .collect::<Vec<_>>()
                                    .join("");
                                system_prompt = Some(text);
                                continue;
                            }
                            _ => MessageRole::User, // Default to user
                        };

                        let content_text = content
                            .parts
                            .iter()
                            .filter_map(|part| part.as_text())
                            .collect::<Vec<_>>()
                            .join("");

                        messages.push(Message {
                            role,
                            content: content_text,
                            tool_calls: None,
                            tool_call_id: None,
                        });
                    }

                    // Convert tools if present
                    let tools = gemini_request.tools.as_ref().map(|gemini_tools| {
                        gemini_tools
                            .iter()
                            .flat_map(|tool| &tool.function_declarations)
                            .map(|decl| crate::llm::provider::ToolDefinition {
                                tool_type: "function".to_string(),
                                function: crate::llm::provider::FunctionDefinition {
                                    name: decl.name.clone(),
                                    description: decl.description.clone(),
                                    parameters: decl.parameters.clone(),
                                },
                            })
                            .collect::<Vec<_>>()
                    });

                    let llm_request = LLMRequest {
                        messages,
                        system_prompt,
                        tools,
                        model: self.model.clone(),
                        max_tokens: gemini_request
                            .generation_config
                            .as_ref()
                            .and_then(|config| config.get("maxOutputTokens"))
                            .and_then(|v| v.as_u64())
                            .map(|v| v as u32),
                        temperature: gemini_request
                            .generation_config
                            .as_ref()
                            .and_then(|config| config.get("temperature"))
                            .and_then(|v| v.as_f64())
                            .map(|v| v as f32),
                        stream: false,
                        tool_choice: None,
                        parallel_tool_calls: None,
                        parallel_tool_config: None,
                        reasoning_effort: None,
                    };

                    // Use the standard LLMProvider generate method
                    let response = LLMProvider::generate(self, llm_request).await?;

                    // If there are tool calls, surface them in the response content as JSON
                    let content = match &response.tool_calls {
                        Some(tool_calls) if !tool_calls.is_empty() => {
                            // Create a JSON structure that the agent can parse
                            let tool_call_json = json!({
                                "tool_calls": tool_calls.iter().map(|tc| {
                                    json!({
                                        "function": {
                                            "name": tc.function.name,
                                            "arguments": tc.function.arguments
                                        }
                                    })
                                }).collect::<Vec<_>>()
                            });
                            tool_call_json.to_string()
                        }
                        _ => response.content.clone().unwrap_or_default(),
                    };

                    return Ok(llm_types::LLMResponse {
                        content,
                        model: self.model.clone(),
                        usage: response.usage.map(|u| llm_types::Usage {
                            prompt_tokens: u.prompt_tokens as usize,
                            completion_tokens: u.completion_tokens as usize,
                            total_tokens: u.total_tokens as usize,
                            cached_prompt_tokens: u.cached_prompt_tokens.map(|v| v as usize),
                            cache_creation_tokens: u.cache_creation_tokens.map(|v| v as usize),
                            cache_read_tokens: u.cache_read_tokens.map(|v| v as usize),
                        }),
                        reasoning: response.reasoning,
                    });
                }
                // Invalid JSON after all: fall back to treating it as a plain prompt
                Err(_) => fallback_request(),
            }
        } else {
            fallback_request()
        };

        let response = LLMProvider::generate(self, request).await?;

        Ok(llm_types::LLMResponse {
            content: response.content.unwrap_or_default(),
            model: self.model.clone(),
            usage: response.usage.map(|u| llm_types::Usage {
                prompt_tokens: u.prompt_tokens as usize,
                completion_tokens: u.completion_tokens as usize,
                total_tokens: u.total_tokens as usize,
                cached_prompt_tokens: u.cached_prompt_tokens.map(|v| v as usize),
                cache_creation_tokens: u.cache_creation_tokens.map(|v| v as usize),
                cache_read_tokens: u.cache_read_tokens.map(|v| v as usize),
            }),
            reasoning: response.reasoning,
        })
    }

    fn backend_kind(&self) -> llm_types::BackendKind {
        llm_types::BackendKind::Gemini
    }

    fn model_id(&self) -> &str {
        &self.model
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::constants::models;
    use crate::llm::provider::{SpecificFunctionChoice, SpecificToolChoice, ToolDefinition};

    #[test]
    fn convert_to_gemini_request_maps_history_and_system_prompt() {
        let provider = GeminiProvider::new("test-key".to_string());
        let mut assistant_message = Message::assistant("Sure thing".to_string());
        assistant_message.tool_calls = Some(vec![ToolCall::function(
            "call_1".to_string(),
            "list_files".to_string(),
            json!({ "path": "." }).to_string(),
        )]);

        let tool_response =
            Message::tool_response("call_1".to_string(), json!({ "result": "ok" }).to_string());

        let tool_def = ToolDefinition::function(
            "list_files".to_string(),
            "List files".to_string(),
            json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string" }
                }
            }),
        );

        let request = LLMRequest {
            messages: vec![
                Message::user("hello".to_string()),
                assistant_message,
                tool_response,
            ],
            system_prompt: Some("System prompt".to_string()),
            tools: Some(vec![tool_def]),
            model: models::google::GEMINI_2_5_FLASH_PREVIEW.to_string(),
            max_tokens: Some(256),
            temperature: Some(0.4),
            stream: false,
            tool_choice: Some(ToolChoice::Specific(SpecificToolChoice {
                tool_type: "function".to_string(),
                function: SpecificFunctionChoice {
                    name: "list_files".to_string(),
                },
            })),
            parallel_tool_calls: None,
            parallel_tool_config: None,
            reasoning_effort: None,
        };

        let gemini_request = provider
            .convert_to_gemini_request(&request)
            .expect("conversion should succeed");

        let system_instruction = gemini_request
            .system_instruction
            .expect("system instruction should be present");
        assert!(matches!(
            system_instruction.parts.as_slice(),
            [Part::Text { text }] if text == "System prompt"
        ));

        assert_eq!(gemini_request.contents.len(), 3);
        assert_eq!(gemini_request.contents[0].role, "user");
        assert!(
            gemini_request.contents[1]
                .parts
                .iter()
                .any(|part| matches!(part, Part::FunctionCall { .. }))
        );
        let tool_part = gemini_request.contents[2]
            .parts
            .iter()
            .find_map(|part| match part {
                Part::FunctionResponse { function_response } => Some(function_response),
                _ => None,
            })
            .expect("tool response part should exist");
        assert_eq!(tool_part.name, "list_files");
    }

    #[test]
    fn convert_from_gemini_response_extracts_tool_calls() {
        let response = GenerateContentResponse {
            candidates: vec![crate::gemini::Candidate {
                content: Content {
                    role: "model".to_string(),
                    parts: vec![
                        Part::Text {
                            text: "Here you go".to_string(),
                        },
                        Part::FunctionCall {
                            function_call: GeminiFunctionCall {
                                name: "list_files".to_string(),
                                args: json!({ "path": "." }),
                                id: Some("call_1".to_string()),
                            },
                        },
                    ],
                },
                finish_reason: Some("FUNCTION_CALL".to_string()),
            }],
            prompt_feedback: None,
            usage_metadata: None,
        };

        let llm_response = GeminiProvider::convert_from_gemini_response(response)
            .expect("conversion should succeed");

        assert_eq!(llm_response.content.as_deref(), Some("Here you go"));
        let calls = llm_response
            .tool_calls
            .expect("tool call should be present");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "list_files");
        assert!(calls[0].function.arguments.contains("path"));
        assert_eq!(llm_response.finish_reason, FinishReason::ToolCalls);
    }

    #[test]
    fn sanitize_function_parameters_removes_additional_properties() {
        let parameters = json!({
            "type": "object",
            "properties": {
                "input": {
                    "type": "object",
                    "properties": {
                        "path": { "type": "string" }
                    },
                    "additionalProperties": false
                }
            },
            "additionalProperties": false
        });

        let sanitized = sanitize_function_parameters(parameters);
        let root = sanitized
            .as_object()
            .expect("root parameters should remain an object");
        assert!(!root.contains_key("additionalProperties"));

        let nested = root
            .get("properties")
            .and_then(|value| value.as_object())
            .and_then(|props| props.get("input"))
            .and_then(|value| value.as_object())
            .expect("nested object should be preserved");
        assert!(!nested.contains_key("additionalProperties"));
    }
}