use crate::config::constants::{models, urls};
use crate::config::core::{GeminiPromptCacheMode, GeminiPromptCacheSettings, PromptCachingConfig};
use crate::gemini::function_calling::{
    FunctionCall as GeminiFunctionCall, FunctionCallingConfig, FunctionResponse,
};
use crate::gemini::models::SystemInstruction;
use crate::gemini::streaming::{
    StreamingCandidate, StreamingError, StreamingProcessor, StreamingResponse,
};
use crate::gemini::{
    Candidate, Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part,
    Tool, ToolConfig,
};
use crate::llm::client::LLMClient;
use crate::llm::error_display;
use crate::llm::provider::{
    FinishReason, FunctionCall, LLMError, LLMProvider, LLMRequest, LLMResponse, LLMStream,
    LLMStreamEvent, Message, MessageRole, ToolCall, ToolChoice,
};
use crate::llm::types as llm_types;
use async_stream::try_stream;
use async_trait::async_trait;
use reqwest::Client as HttpClient;
use serde_json::{Map, Value, json};
use std::collections::HashMap;
use tokio::sync::mpsc;

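/// LLM provider backed by the Google Gemini `generateContent` API.
///
/// Holds the API key, a reusable HTTP client, the base URL, the selected
/// model, and the prompt-cache settings derived from configuration.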
pub struct GeminiProvider {
    api_key: String,
    http_client: HttpClient,
    base_url: String,
    model: String,
    prompt_cache_enabled: bool,
    prompt_cache_settings: GeminiPromptCacheSettings,
}

impl GeminiProvider {
    pub fn new(api_key: String) -> Self {
        Self::with_model_internal(api_key, models::GEMINI_2_5_FLASH_PREVIEW.to_string(), None)
    }

    pub fn with_model(api_key: String, model: String) -> Self {
        Self::with_model_internal(api_key, model, None)
    }

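    /// Builds a provider from optional configuration values, falling back to
    /// the default Gemini model when none is given and overriding the base
    /// URL if one is supplied.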
    pub fn from_config(
        api_key: Option<String>,
        model: Option<String>,
        base_url: Option<String>,
        prompt_cache: Option<PromptCachingConfig>,
    ) -> Self {
        let api_key_value = api_key.unwrap_or_default();
        let mut provider = if let Some(model_value) = model {
            Self::with_model_internal(api_key_value, model_value, prompt_cache)
        } else {
            Self::with_model_internal(
                api_key_value,
                models::GEMINI_2_5_FLASH_PREVIEW.to_string(),
                prompt_cache,
            )
        };
        if let Some(base) = base_url {
            provider.base_url = base;
        }
        provider
    }

    fn with_model_internal(
        api_key: String,
        model: String,
        prompt_cache: Option<PromptCachingConfig>,
    ) -> Self {
        let (prompt_cache_enabled, prompt_cache_settings) =
            Self::extract_prompt_cache_settings(prompt_cache);

        Self {
            api_key,
            http_client: HttpClient::new(),
            base_url: urls::GEMINI_API_BASE.to_string(),
            model,
            prompt_cache_enabled,
            prompt_cache_settings,
        }
    }

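    /// Derives the effective prompt-cache flag and settings: caching is only
    /// enabled when both the global and Gemini-specific switches are on and
    /// the provider mode is not `Off`.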
    fn extract_prompt_cache_settings(
        prompt_cache: Option<PromptCachingConfig>,
    ) -> (bool, GeminiPromptCacheSettings) {
        if let Some(cfg) = prompt_cache {
            let provider_settings = cfg.providers.gemini;
            let enabled = cfg.enabled
                && provider_settings.enabled
                && provider_settings.mode != GeminiPromptCacheMode::Off;
            (enabled, provider_settings)
        } else {
            (false, GeminiPromptCacheSettings::default())
        }
    }
}

#[async_trait]
impl LLMProvider for GeminiProvider {
    fn name(&self) -> &str {
        "gemini"
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    fn supports_reasoning(&self, _model: &str) -> bool {
        false
    }

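    /// Non-streaming generation: POSTs the converted request to
    /// `models/{model}:generateContent` and maps HTTP, quota, and parse
    /// failures to `LLMError` variants before converting the response.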
    async fn generate(&self, request: LLMRequest) -> Result<LLMResponse, LLMError> {
        let gemini_request = self.convert_to_gemini_request(&request)?;

        let url = format!(
            "{}/models/{}:generateContent?key={}",
            self.base_url, request.model, self.api_key
        );

        let response = self
            .http_client
            .post(&url)
            .json(&gemini_request)
            .send()
            .await
            .map_err(|e| {
                let formatted_error =
                    error_display::format_llm_error("Gemini", &format!("Network error: {}", e));
                LLMError::Network(formatted_error)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();

            if status.as_u16() == 429
                || error_text.contains("insufficient_quota")
                || error_text.contains("quota")
                || error_text.contains("rate limit")
            {
                return Err(LLMError::RateLimit);
            }

            let formatted_error = error_display::format_llm_error(
                "Gemini",
                &format!("HTTP {}: {}", status, error_text),
            );
            return Err(LLMError::Provider(formatted_error));
        }

        let gemini_response: GenerateContentResponse = response.json().await.map_err(|e| {
            let formatted_error = error_display::format_llm_error(
                "Gemini",
                &format!("Failed to parse response: {}", e),
            );
            LLMError::Provider(formatted_error)
        })?;

        Self::convert_from_gemini_response(gemini_response)
    }

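    /// Streaming generation: POSTs to `models/{model}:streamGenerateContent`,
    /// then forwards token deltas and a final completion event through an
    /// unbounded channel that backs the returned stream.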
    async fn stream(&self, request: LLMRequest) -> Result<LLMStream, LLMError> {
        let gemini_request = self.convert_to_gemini_request(&request)?;

        let url = format!(
            "{}/models/{}:streamGenerateContent?key={}",
            self.base_url, request.model, self.api_key
        );

        let response = self
            .http_client
            .post(&url)
            .json(&gemini_request)
            .send()
            .await
            .map_err(|e| {
                let formatted_error =
                    error_display::format_llm_error("Gemini", &format!("Network error: {}", e));
                LLMError::Network(formatted_error)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();

            if status.as_u16() == 401 || status.as_u16() == 403 {
                let formatted_error = error_display::format_llm_error(
                    "Gemini",
                    &format!("HTTP {}: {}", status, error_text),
                );
                return Err(LLMError::Authentication(formatted_error));
            }

            if status.as_u16() == 429
                || error_text.contains("insufficient_quota")
                || error_text.contains("quota")
                || error_text.contains("rate limit")
            {
                return Err(LLMError::RateLimit);
            }

            let formatted_error = error_display::format_llm_error(
                "Gemini",
                &format!("HTTP {}: {}", status, error_text),
            );
            return Err(LLMError::Provider(formatted_error));
        }

        let (event_tx, event_rx) = mpsc::unbounded_channel::<Result<LLMStreamEvent, LLMError>>();
        let completion_sender = event_tx.clone();

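        // Process the response body on a background task: every non-empty
        // text chunk is forwarded as a Token event, and a Completed (or
        // error) event is sent once the stream finishes.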
        tokio::spawn(async move {
            let mut processor = StreamingProcessor::new();
            let token_sender = completion_sender.clone();
            let mut aggregated_text = String::new();
            let mut on_chunk = |chunk: &str| -> Result<(), StreamingError> {
                if chunk.is_empty() {
                    return Ok(());
                }

                aggregated_text.push_str(chunk);

                token_sender
                    .send(Ok(LLMStreamEvent::Token {
                        delta: chunk.to_string(),
                    }))
                    .map_err(|_| StreamingError::StreamingError {
                        message: "Streaming consumer dropped".to_string(),
                        partial_content: Some(chunk.to_string()),
                    })?;
                Ok(())
            };

            let result = processor.process_stream(response, &mut on_chunk).await;
            match result {
                Ok(mut streaming_response) => {
                    if streaming_response.candidates.is_empty()
                        && !aggregated_text.trim().is_empty()
                    {
                        streaming_response.candidates.push(StreamingCandidate {
                            content: Content {
                                role: "model".to_string(),
                                parts: vec![Part::Text {
                                    text: aggregated_text.clone(),
                                }],
                            },
                            finish_reason: None,
                            index: Some(0),
                        });
                    }

                    match Self::convert_from_streaming_response(streaming_response) {
                        Ok(final_response) => {
                            let _ = completion_sender.send(Ok(LLMStreamEvent::Completed {
                                response: final_response,
                            }));
                        }
                        Err(err) => {
                            let _ = completion_sender.send(Err(err));
                        }
                    }
                }
                Err(error) => {
                    let mapped = Self::map_streaming_error(error);
                    let _ = completion_sender.send(Err(mapped));
                }
            }
        });

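        // Drop the original sender so the spawned task's clone is the only
        // one left; the receiver stream then ends when that task finishes.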
        drop(event_tx);

        let stream = {
            let mut receiver = event_rx;
            try_stream! {
                while let Some(event) = receiver.recv().await {
                    yield event?;
                }
            }
        };

        Ok(Box::pin(stream))
    }

    fn supported_models(&self) -> Vec<String> {
        vec![
            models::google::GEMINI_2_5_FLASH_PREVIEW.to_string(),
            models::google::GEMINI_2_5_PRO.to_string(),
        ]
    }

    fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
        if !self.supported_models().contains(&request.model) {
            let formatted_error = error_display::format_llm_error(
                "Gemini",
                &format!("Unsupported model: {}", request.model),
            );
            return Err(LLMError::InvalidRequest(formatted_error));
        }
        Ok(())
    }
}

impl GeminiProvider {
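    /// Converts an `LLMRequest` into a Gemini `GenerateContentRequest`:
    /// system messages become the system instruction, assistant tool calls
    /// and tool results become function-call/function-response parts, tool
    /// parameter schemas are sanitized, and `tool_choice` is mapped onto a
    /// `ToolConfig`.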
    fn convert_to_gemini_request(
        &self,
        request: &LLMRequest,
    ) -> Result<GenerateContentRequest, LLMError> {
        if self.prompt_cache_enabled
            && matches!(
                self.prompt_cache_settings.mode,
                GeminiPromptCacheMode::Explicit
            )
        {
        }

        let mut call_map: HashMap<String, String> = HashMap::new();
        for message in &request.messages {
            if message.role == MessageRole::Assistant
                && let Some(tool_calls) = &message.tool_calls
            {
                for tool_call in tool_calls {
                    call_map.insert(tool_call.id.clone(), tool_call.function.name.clone());
                }
            }
        }

        let mut contents: Vec<Content> = Vec::new();
        for message in &request.messages {
            if message.role == MessageRole::System {
                continue;
            }

            let mut parts: Vec<Part> = Vec::new();
            if message.role != MessageRole::Tool && !message.content.is_empty() {
                parts.push(Part::Text {
                    text: message.content.clone(),
                });
            }

            if message.role == MessageRole::Assistant
                && let Some(tool_calls) = &message.tool_calls
            {
                for tool_call in tool_calls {
                    let parsed_args = serde_json::from_str(&tool_call.function.arguments)
                        .unwrap_or_else(|_| json!({}));
                    parts.push(Part::FunctionCall {
                        function_call: GeminiFunctionCall {
                            name: tool_call.function.name.clone(),
                            args: parsed_args,
                            id: Some(tool_call.id.clone()),
                        },
                    });
                }
            }

            if message.role == MessageRole::Tool {
                if let Some(tool_call_id) = &message.tool_call_id {
                    let func_name = call_map
                        .get(tool_call_id)
                        .cloned()
                        .unwrap_or_else(|| tool_call_id.clone());
                    let response_text = serde_json::from_str::<Value>(&message.content)
                        .map(|value| {
                            serde_json::to_string_pretty(&value)
                                .unwrap_or_else(|_| message.content.clone())
                        })
                        .unwrap_or_else(|_| message.content.clone());

                    let response_payload = json!({
                        "name": func_name.clone(),
                        "content": [{
                            "text": response_text
                        }]
                    });

                    parts.push(Part::FunctionResponse {
                        function_response: FunctionResponse {
                            name: func_name,
                            response: response_payload,
                        },
                    });
                } else if !message.content.is_empty() {
                    parts.push(Part::Text {
                        text: message.content.clone(),
                    });
                }
            }

            if !parts.is_empty() {
                contents.push(Content {
                    role: message.role.as_gemini_str().to_string(),
                    parts,
                });
            }
        }

        let tools: Option<Vec<Tool>> = request.tools.as_ref().map(|definitions| {
            definitions
                .iter()
                .map(|tool| Tool {
                    function_declarations: vec![FunctionDeclaration {
                        name: tool.function.name.clone(),
                        description: tool.function.description.clone(),
                        parameters: sanitize_function_parameters(tool.function.parameters.clone()),
                    }],
                })
                .collect()
        });

        let mut generation_config = Map::new();
        if let Some(max_tokens) = request.max_tokens {
            generation_config.insert("maxOutputTokens".to_string(), json!(max_tokens));
        }
        if let Some(temp) = request.temperature {
            generation_config.insert("temperature".to_string(), json!(temp));
        }
        let has_tools = request
            .tools
            .as_ref()
            .map(|defs| !defs.is_empty())
            .unwrap_or(false);
        let tool_config = if has_tools || request.tool_choice.is_some() {
            Some(match request.tool_choice.as_ref() {
                Some(ToolChoice::None) => ToolConfig {
                    function_calling_config: FunctionCallingConfig::none(),
                },
                Some(ToolChoice::Any) => ToolConfig {
                    function_calling_config: FunctionCallingConfig::any(),
                },
                Some(ToolChoice::Specific(spec)) => {
                    let mut config = FunctionCallingConfig::any();
                    if spec.tool_type == "function" {
                        config.allowed_function_names = Some(vec![spec.function.name.clone()]);
                    }
                    ToolConfig {
                        function_calling_config: config,
                    }
                }
                _ => ToolConfig::auto(),
            })
        } else {
            None
        };

        Ok(GenerateContentRequest {
            contents,
            tools,
            tool_config,
            system_instruction: request
                .system_prompt
                .as_ref()
                .map(|text| SystemInstruction::new(text.clone())),
            generation_config: if generation_config.is_empty() {
                None
            } else {
                Some(Value::Object(generation_config))
            },
            reasoning_config: None,
        })
    }

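    /// Converts the first Gemini candidate into an `LLMResponse`,
    /// concatenating text parts, collecting function calls as tool calls
    /// (generating an id when Gemini omits one), and mapping the finish
    /// reason.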
    fn convert_from_gemini_response(
        response: GenerateContentResponse,
    ) -> Result<LLMResponse, LLMError> {
        let mut candidates = response.candidates.into_iter();
        let candidate = candidates.next().ok_or_else(|| {
            let formatted_error =
                error_display::format_llm_error("Gemini", "No candidate in response");
            LLMError::Provider(formatted_error)
        })?;

        if candidate.content.parts.is_empty() {
            return Ok(LLMResponse {
                content: Some(String::new()),
                tool_calls: None,
                usage: None,
                finish_reason: FinishReason::Stop,
                reasoning: None,
            });
        }

        let mut text_content = String::new();
        let mut tool_calls = Vec::new();

        for part in candidate.content.parts {
            match part {
                Part::Text { text } => {
                    text_content.push_str(&text);
                }
                Part::FunctionCall { function_call } => {
                    let call_id = function_call.id.clone().unwrap_or_else(|| {
                        format!(
                            "call_{}_{}",
                            std::time::SystemTime::now()
                                .duration_since(std::time::UNIX_EPOCH)
                                .unwrap_or_default()
                                .as_nanos(),
                            tool_calls.len()
                        )
                    });
                    tool_calls.push(ToolCall {
                        id: call_id,
                        call_type: "function".to_string(),
                        function: FunctionCall {
                            name: function_call.name,
                            arguments: serde_json::to_string(&function_call.args)
                                .unwrap_or_else(|_| "{}".to_string()),
                        },
                    });
                }
                Part::FunctionResponse { .. } => {
                }
            }
        }

        let finish_reason = match candidate.finish_reason.as_deref() {
            Some("STOP") => FinishReason::Stop,
            Some("MAX_TOKENS") => FinishReason::Length,
            Some("SAFETY") => FinishReason::ContentFilter,
            Some("FUNCTION_CALL") => FinishReason::ToolCalls,
            Some(other) => FinishReason::Error(other.to_string()),
            None => FinishReason::Stop,
        };

        Ok(LLMResponse {
            content: if text_content.is_empty() {
                None
            } else {
                Some(text_content)
            },
            tool_calls: if tool_calls.is_empty() {
                None
            } else {
                Some(tool_calls)
            },
            usage: None,
            finish_reason,
            reasoning: None,
        })
    }

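    /// Converts an aggregated streaming response into an `LLMResponse` by
    /// reusing the non-streaming conversion path.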
    fn convert_from_streaming_response(
        response: StreamingResponse,
    ) -> Result<LLMResponse, LLMError> {
        let converted_candidates: Vec<Candidate> = response
            .candidates
            .into_iter()
            .map(|candidate| Candidate {
                content: candidate.content,
                finish_reason: candidate.finish_reason,
            })
            .collect();

        let converted = GenerateContentResponse {
            candidates: converted_candidates,
            prompt_feedback: None,
            usage_metadata: response.usage_metadata,
        };

        Self::convert_from_gemini_response(converted)
    }

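    /// Maps streaming-layer errors onto `LLMError` variants, distinguishing
    /// authentication, rate-limit, network, and provider failures.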
    fn map_streaming_error(error: StreamingError) -> LLMError {
        match error {
            StreamingError::NetworkError { message, .. } => {
                let formatted = error_display::format_llm_error(
                    "Gemini",
                    &format!("Network error: {}", message),
                );
                LLMError::Network(formatted)
            }
            StreamingError::ApiError {
                status_code,
                message,
                ..
            } => {
                if status_code == 401 || status_code == 403 {
                    let formatted = error_display::format_llm_error(
                        "Gemini",
                        &format!("HTTP {}: {}", status_code, message),
                    );
                    LLMError::Authentication(formatted)
                } else if status_code == 429 {
                    LLMError::RateLimit
                } else {
                    let formatted = error_display::format_llm_error(
                        "Gemini",
                        &format!("API error ({}): {}", status_code, message),
                    );
                    LLMError::Provider(formatted)
                }
            }
            StreamingError::ParseError { message, .. } => {
                let formatted =
                    error_display::format_llm_error("Gemini", &format!("Parse error: {}", message));
                LLMError::Provider(formatted)
            }
            StreamingError::TimeoutError {
                operation,
                duration,
            } => {
                let formatted = error_display::format_llm_error(
                    "Gemini",
                    &format!(
                        "Streaming timeout during {} after {:?}",
                        operation, duration
                    ),
                );
                LLMError::Network(formatted)
            }
            StreamingError::ContentError { message } => {
                let formatted = error_display::format_llm_error(
                    "Gemini",
                    &format!("Content error: {}", message),
                );
                LLMError::Provider(formatted)
            }
            StreamingError::StreamingError { message, .. } => {
                let formatted = error_display::format_llm_error(
                    "Gemini",
                    &format!("Streaming error: {}", message),
                );
                LLMError::Provider(formatted)
            }
        }
    }
}

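/// Recursively strips `additionalProperties` keys from tool parameter JSON
/// schemas so the declarations stay within the schema subset accepted by the
/// Gemini function-calling API.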
fn sanitize_function_parameters(parameters: Value) -> Value {
    match parameters {
        Value::Object(map) => {
            let mut sanitized = Map::new();
            for (key, value) in map {
                if key == "additionalProperties" {
                    continue;
                }
                sanitized.insert(key, sanitize_function_parameters(value));
            }
            Value::Object(sanitized)
        }
        Value::Array(values) => Value::Array(
            values
                .into_iter()
                .map(sanitize_function_parameters)
                .collect(),
        ),
        other => other,
    }
}

#[async_trait]
impl LLMClient for GeminiProvider {
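    /// Prompt-string entry point: if the prompt looks like a serialized
    /// `GenerateContentRequest` it is parsed and replayed through the
    /// provider; otherwise it is wrapped as a single user message.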
    async fn generate(&mut self, prompt: &str) -> Result<llm_types::LLMResponse, LLMError> {
        let request = if prompt.starts_with('{') && prompt.contains("\"contents\"") {
            match serde_json::from_str::<crate::gemini::GenerateContentRequest>(prompt) {
                Ok(gemini_request) => {
                    let mut messages = Vec::new();
                    let mut system_prompt = None;

                    for content in &gemini_request.contents {
                        let role = match content.role.as_str() {
                            crate::config::constants::message_roles::USER => MessageRole::User,
                            "model" => MessageRole::Assistant,
                            crate::config::constants::message_roles::SYSTEM => {
                                let text = content
                                    .parts
                                    .iter()
                                    .filter_map(|part| part.as_text())
                                    .collect::<Vec<_>>()
                                    .join("");
                                system_prompt = Some(text);
                                continue;
                            }
                            _ => MessageRole::User,
                        };

                        let content_text = content
                            .parts
                            .iter()
                            .filter_map(|part| part.as_text())
                            .collect::<Vec<_>>()
                            .join("");

                        messages.push(Message {
                            role,
                            content: content_text,
                            tool_calls: None,
                            tool_call_id: None,
                        });
                    }

                    let tools = gemini_request.tools.as_ref().map(|gemini_tools| {
                        gemini_tools
                            .iter()
                            .flat_map(|tool| &tool.function_declarations)
                            .map(|decl| crate::llm::provider::ToolDefinition {
                                tool_type: "function".to_string(),
                                function: crate::llm::provider::FunctionDefinition {
                                    name: decl.name.clone(),
                                    description: decl.description.clone(),
                                    parameters: decl.parameters.clone(),
                                },
                            })
                            .collect::<Vec<_>>()
                    });

                    let llm_request = LLMRequest {
                        messages,
                        system_prompt,
                        tools,
                        model: self.model.clone(),
                        max_tokens: gemini_request
                            .generation_config
                            .as_ref()
                            .and_then(|config| config.get("maxOutputTokens"))
                            .and_then(|v| v.as_u64())
                            .map(|v| v as u32),
                        temperature: gemini_request
                            .generation_config
                            .as_ref()
                            .and_then(|config| config.get("temperature"))
                            .and_then(|v| v.as_f64())
                            .map(|v| v as f32),
                        stream: false,
                        tool_choice: None,
                        parallel_tool_calls: None,
                        parallel_tool_config: None,
                        reasoning_effort: None,
                    };

                    let response = LLMProvider::generate(self, llm_request).await?;

                    let content = if let Some(tool_calls) = &response.tool_calls {
                        if !tool_calls.is_empty() {
                            let tool_call_json = json!({
                                "tool_calls": tool_calls.iter().map(|tc| {
                                    json!({
                                        "function": {
                                            "name": tc.function.name,
                                            "arguments": tc.function.arguments
                                        }
                                    })
                                }).collect::<Vec<_>>()
                            });
                            tool_call_json.to_string()
                        } else {
                            response.content.unwrap_or("".to_string())
                        }
                    } else {
                        response.content.unwrap_or("".to_string())
                    };

                    return Ok(llm_types::LLMResponse {
                        content,
                        model: self.model.clone(),
                        usage: response.usage.map(|u| llm_types::Usage {
                            prompt_tokens: u.prompt_tokens as usize,
                            completion_tokens: u.completion_tokens as usize,
                            total_tokens: u.total_tokens as usize,
                            cached_prompt_tokens: u.cached_prompt_tokens.map(|v| v as usize),
                            cache_creation_tokens: u.cache_creation_tokens.map(|v| v as usize),
                            cache_read_tokens: u.cache_read_tokens.map(|v| v as usize),
                        }),
                        reasoning: response.reasoning,
                    });
                }
                Err(_) => {
                    LLMRequest {
                        messages: vec![Message {
                            role: MessageRole::User,
                            content: prompt.to_string(),
                            tool_calls: None,
                            tool_call_id: None,
                        }],
                        system_prompt: None,
                        tools: None,
                        model: self.model.clone(),
                        max_tokens: None,
                        temperature: None,
                        stream: false,
                        tool_choice: None,
                        parallel_tool_calls: None,
                        parallel_tool_config: None,
                        reasoning_effort: None,
                    }
                }
            }
        } else {
            LLMRequest {
                messages: vec![Message {
                    role: MessageRole::User,
                    content: prompt.to_string(),
                    tool_calls: None,
                    tool_call_id: None,
                }],
                system_prompt: None,
                tools: None,
                model: self.model.clone(),
                max_tokens: None,
                temperature: None,
                stream: false,
                tool_choice: None,
                parallel_tool_calls: None,
                parallel_tool_config: None,
                reasoning_effort: None,
            }
        };

        let response = LLMProvider::generate(self, request).await?;

        Ok(llm_types::LLMResponse {
            content: response.content.unwrap_or("".to_string()),
            model: self.model.clone(),
            usage: response.usage.map(|u| llm_types::Usage {
                prompt_tokens: u.prompt_tokens as usize,
                completion_tokens: u.completion_tokens as usize,
                total_tokens: u.total_tokens as usize,
                cached_prompt_tokens: u.cached_prompt_tokens.map(|v| v as usize),
                cache_creation_tokens: u.cache_creation_tokens.map(|v| v as usize),
                cache_read_tokens: u.cache_read_tokens.map(|v| v as usize),
            }),
            reasoning: response.reasoning,
        })
    }

    fn backend_kind(&self) -> llm_types::BackendKind {
        llm_types::BackendKind::Gemini
    }

    fn model_id(&self) -> &str {
        &self.model
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::constants::models;
    use crate::llm::provider::{SpecificFunctionChoice, SpecificToolChoice, ToolDefinition};

    #[test]
    fn convert_to_gemini_request_maps_history_and_system_prompt() {
        let provider = GeminiProvider::new("test-key".to_string());
        let mut assistant_message = Message::assistant("Sure thing".to_string());
        assistant_message.tool_calls = Some(vec![ToolCall::function(
            "call_1".to_string(),
            "list_files".to_string(),
            json!({ "path": "." }).to_string(),
        )]);

        let tool_response =
            Message::tool_response("call_1".to_string(), json!({ "result": "ok" }).to_string());

        let tool_def = ToolDefinition::function(
            "list_files".to_string(),
            "List files".to_string(),
            json!({
                "type": "object",
                "properties": {
                    "path": { "type": "string" }
                }
            }),
        );

        let request = LLMRequest {
            messages: vec![
                Message::user("hello".to_string()),
                assistant_message,
                tool_response,
            ],
            system_prompt: Some("System prompt".to_string()),
            tools: Some(vec![tool_def]),
            model: models::google::GEMINI_2_5_FLASH_PREVIEW.to_string(),
            max_tokens: Some(256),
            temperature: Some(0.4),
            stream: false,
            tool_choice: Some(ToolChoice::Specific(SpecificToolChoice {
                tool_type: "function".to_string(),
                function: SpecificFunctionChoice {
                    name: "list_files".to_string(),
                },
            })),
            parallel_tool_calls: None,
            parallel_tool_config: None,
            reasoning_effort: None,
        };

        let gemini_request = provider
            .convert_to_gemini_request(&request)
            .expect("conversion should succeed");

        let system_instruction = gemini_request
            .system_instruction
            .expect("system instruction should be present");
        assert!(matches!(
            system_instruction.parts.as_slice(),
            [Part::Text { text }] if text == "System prompt"
        ));

        assert_eq!(gemini_request.contents.len(), 3);
        assert_eq!(gemini_request.contents[0].role, "user");
        assert!(
            gemini_request.contents[1]
                .parts
                .iter()
                .any(|part| matches!(part, Part::FunctionCall { .. }))
        );
        let tool_part = gemini_request.contents[2]
            .parts
            .iter()
            .find_map(|part| match part {
                Part::FunctionResponse { function_response } => Some(function_response),
                _ => None,
            })
            .expect("tool response part should exist");
        assert_eq!(tool_part.name, "list_files");
    }

    #[test]
    fn convert_from_gemini_response_extracts_tool_calls() {
        let response = GenerateContentResponse {
            candidates: vec![crate::gemini::Candidate {
                content: Content {
                    role: "model".to_string(),
                    parts: vec![
                        Part::Text {
                            text: "Here you go".to_string(),
                        },
                        Part::FunctionCall {
                            function_call: GeminiFunctionCall {
                                name: "list_files".to_string(),
                                args: json!({ "path": "." }),
                                id: Some("call_1".to_string()),
                            },
                        },
                    ],
                },
                finish_reason: Some("FUNCTION_CALL".to_string()),
            }],
            prompt_feedback: None,
            usage_metadata: None,
        };

        let llm_response = GeminiProvider::convert_from_gemini_response(response)
            .expect("conversion should succeed");

        assert_eq!(llm_response.content.as_deref(), Some("Here you go"));
        let calls = llm_response
            .tool_calls
            .expect("tool call should be present");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "list_files");
        assert!(calls[0].function.arguments.contains("path"));
        assert_eq!(llm_response.finish_reason, FinishReason::ToolCalls);
    }

    #[test]
    fn sanitize_function_parameters_removes_additional_properties() {
        let parameters = json!({
            "type": "object",
            "properties": {
                "input": {
                    "type": "object",
                    "properties": {
                        "path": { "type": "string" }
                    },
                    "additionalProperties": false
                }
            },
            "additionalProperties": false
        });

        let sanitized = sanitize_function_parameters(parameters);
        let root = sanitized
            .as_object()
            .expect("root parameters should remain an object");
        assert!(!root.contains_key("additionalProperties"));

        let nested = root
            .get("properties")
            .and_then(|value| value.as_object())
            .and_then(|props| props.get("input"))
            .and_then(|value| value.as_object())
            .expect("nested object should be preserved");
        assert!(!nested.contains_key("additionalProperties"));
    }
}