1use crate::config::constants::{models, urls};
2use crate::config::core::{GeminiPromptCacheMode, GeminiPromptCacheSettings, PromptCachingConfig};
3use crate::gemini::function_calling::{
4 FunctionCall as GeminiFunctionCall, FunctionCallingConfig, FunctionResponse,
5};
6use crate::gemini::models::SystemInstruction;
7use crate::gemini::streaming::{
8 StreamingCandidate, StreamingError, StreamingProcessor, StreamingResponse,
9};
10use crate::gemini::{
11 Candidate, Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, Part,
12 Tool, ToolConfig,
13};
14use crate::llm::client::LLMClient;
15use crate::llm::error_display;
16use crate::llm::provider::{
17 FinishReason, FunctionCall, LLMError, LLMProvider, LLMRequest, LLMResponse, LLMStream,
18 LLMStreamEvent, Message, MessageRole, ToolCall, ToolChoice,
19};
20use crate::llm::types as llm_types;
21use async_stream::try_stream;
22use async_trait::async_trait;
23use reqwest::Client as HttpClient;
24use serde_json::{Map, Value, json};
25use std::collections::HashMap;
26use tokio::sync::mpsc;
27
/// LLM provider that talks to the Gemini HTTP API.
pub struct GeminiProvider {
    /// API key sent as the `key` query parameter of every request URL.
    api_key: String,
    /// HTTP client used to issue all requests.
    http_client: HttpClient,
    /// Base URL that the `/models/{model}:...` endpoint paths are appended to.
    base_url: String,
    /// Model identifier returned by `model_id` and used for prompt-based requests.
    model: String,
    /// `true` only when caching is enabled globally and for Gemini, and the
    /// Gemini cache mode is not `Off`.
    prompt_cache_enabled: bool,
    /// Gemini-specific prompt cache settings taken from the configuration.
    prompt_cache_settings: GeminiPromptCacheSettings,
}
36
37impl GeminiProvider {
38 pub fn new(api_key: String) -> Self {
39 Self::with_model_internal(api_key, models::GEMINI_2_5_FLASH_PREVIEW.to_string(), None)
40 }
41
42 pub fn with_model(api_key: String, model: String) -> Self {
43 Self::with_model_internal(api_key, model, None)
44 }
45
46 pub fn from_config(
47 api_key: Option<String>,
48 model: Option<String>,
49 base_url: Option<String>,
50 prompt_cache: Option<PromptCachingConfig>,
51 ) -> Self {
52 let api_key_value = api_key.unwrap_or_default();
53 let mut provider = if let Some(model_value) = model {
54 Self::with_model_internal(api_key_value, model_value, prompt_cache)
55 } else {
56 Self::with_model_internal(
57 api_key_value,
58 models::GEMINI_2_5_FLASH_PREVIEW.to_string(),
59 prompt_cache,
60 )
61 };
62 if let Some(base) = base_url {
63 provider.base_url = base;
64 }
65 provider
66 }
67
68 fn with_model_internal(
69 api_key: String,
70 model: String,
71 prompt_cache: Option<PromptCachingConfig>,
72 ) -> Self {
73 let (prompt_cache_enabled, prompt_cache_settings) =
74 Self::extract_prompt_cache_settings(prompt_cache);
75
76 Self {
77 api_key,
78 http_client: HttpClient::new(),
79 base_url: urls::GEMINI_API_BASE.to_string(),
80 model,
81 prompt_cache_enabled,
82 prompt_cache_settings,
83 }
84 }
85
86 fn extract_prompt_cache_settings(
87 prompt_cache: Option<PromptCachingConfig>,
88 ) -> (bool, GeminiPromptCacheSettings) {
89 if let Some(cfg) = prompt_cache {
90 let provider_settings = cfg.providers.gemini;
91 let enabled = cfg.enabled
92 && provider_settings.enabled
93 && provider_settings.mode != GeminiPromptCacheMode::Off;
94 (enabled, provider_settings)
95 } else {
96 (false, GeminiPromptCacheSettings::default())
97 }
98 }
99}
100
101#[async_trait]
102impl LLMProvider for GeminiProvider {
103 fn name(&self) -> &str {
104 "gemini"
105 }
106
107 fn supports_streaming(&self) -> bool {
108 true
109 }
110
111 fn supports_reasoning(&self, _model: &str) -> bool {
112 false
113 }
114
115 async fn generate(&self, request: LLMRequest) -> Result<LLMResponse, LLMError> {
116 let gemini_request = self.convert_to_gemini_request(&request)?;
117
118 let url = format!(
119 "{}/models/{}:generateContent?key={}",
120 self.base_url, request.model, self.api_key
121 );
122
123 let response = self
124 .http_client
125 .post(&url)
126 .json(&gemini_request)
127 .send()
128 .await
129 .map_err(|e| {
130 let formatted_error =
131 error_display::format_llm_error("Gemini", &format!("Network error: {}", e));
132 LLMError::Network(formatted_error)
133 })?;
134
135 if !response.status().is_success() {
136 let status = response.status();
137 let error_text = response.text().await.unwrap_or_default();
138
139 if status.as_u16() == 429
141 || error_text.contains("insufficient_quota")
142 || error_text.contains("quota")
143 || error_text.contains("rate limit")
144 {
145 return Err(LLMError::RateLimit);
146 }
147
148 let formatted_error = error_display::format_llm_error(
149 "Gemini",
150 &format!("HTTP {}: {}", status, error_text),
151 );
152 return Err(LLMError::Provider(formatted_error));
153 }
154
155 let gemini_response: GenerateContentResponse = response.json().await.map_err(|e| {
156 let formatted_error = error_display::format_llm_error(
157 "Gemini",
158 &format!("Failed to parse response: {}", e),
159 );
160 LLMError::Provider(formatted_error)
161 })?;
162
163 Self::convert_from_gemini_response(gemini_response)
164 }
165
166 async fn stream(&self, request: LLMRequest) -> Result<LLMStream, LLMError> {
167 let gemini_request = self.convert_to_gemini_request(&request)?;
168
169 let url = format!(
170 "{}/models/{}:streamGenerateContent?key={}",
171 self.base_url, request.model, self.api_key
172 );
173
174 let response = self
175 .http_client
176 .post(&url)
177 .json(&gemini_request)
178 .send()
179 .await
180 .map_err(|e| {
181 let formatted_error =
182 error_display::format_llm_error("Gemini", &format!("Network error: {}", e));
183 LLMError::Network(formatted_error)
184 })?;
185
186 if !response.status().is_success() {
187 let status = response.status();
188 let error_text = response.text().await.unwrap_or_default();
189
190 if status.as_u16() == 401 || status.as_u16() == 403 {
191 let formatted_error = error_display::format_llm_error(
192 "Gemini",
193 &format!("HTTP {}: {}", status, error_text),
194 );
195 return Err(LLMError::Authentication(formatted_error));
196 }
197
198 if status.as_u16() == 429
199 || error_text.contains("insufficient_quota")
200 || error_text.contains("quota")
201 || error_text.contains("rate limit")
202 {
203 return Err(LLMError::RateLimit);
204 }
205
206 let formatted_error = error_display::format_llm_error(
207 "Gemini",
208 &format!("HTTP {}: {}", status, error_text),
209 );
210 return Err(LLMError::Provider(formatted_error));
211 }
212
213 let (event_tx, event_rx) = mpsc::unbounded_channel::<Result<LLMStreamEvent, LLMError>>();
214 let completion_sender = event_tx.clone();
215
216 tokio::spawn(async move {
217 let mut processor = StreamingProcessor::new();
218 let token_sender = completion_sender.clone();
219 let mut aggregated_text = String::new();
220 let mut on_chunk = |chunk: &str| -> Result<(), StreamingError> {
221 if chunk.is_empty() {
222 return Ok(());
223 }
224
225 aggregated_text.push_str(chunk);
226
227 token_sender
228 .send(Ok(LLMStreamEvent::Token {
229 delta: chunk.to_string(),
230 }))
231 .map_err(|_| StreamingError::StreamingError {
232 message: "Streaming consumer dropped".to_string(),
233 partial_content: Some(chunk.to_string()),
234 })?;
235 Ok(())
236 };
237
238 let result = processor.process_stream(response, &mut on_chunk).await;
239 match result {
240 Ok(mut streaming_response) => {
241 if streaming_response.candidates.is_empty()
242 && !aggregated_text.trim().is_empty()
243 {
244 streaming_response.candidates.push(StreamingCandidate {
245 content: Content {
246 role: "model".to_string(),
247 parts: vec![Part::Text {
248 text: aggregated_text.clone(),
249 }],
250 },
251 finish_reason: None,
252 index: Some(0),
253 });
254 }
255
256 match Self::convert_from_streaming_response(streaming_response) {
257 Ok(final_response) => {
258 let _ = completion_sender.send(Ok(LLMStreamEvent::Completed {
259 response: final_response,
260 }));
261 }
262 Err(err) => {
263 let _ = completion_sender.send(Err(err));
264 }
265 }
266 }
267 Err(error) => {
268 let mapped = Self::map_streaming_error(error);
269 let _ = completion_sender.send(Err(mapped));
270 }
271 }
272 });
273
274 drop(event_tx);
275
276 let stream = {
277 let mut receiver = event_rx;
278 try_stream! {
279 while let Some(event) = receiver.recv().await {
280 yield event?;
281 }
282 }
283 };
284
285 Ok(Box::pin(stream))
286 }
287
288 fn supported_models(&self) -> Vec<String> {
289 vec![
290 models::google::GEMINI_2_5_FLASH_PREVIEW.to_string(),
291 models::google::GEMINI_2_5_PRO.to_string(),
292 ]
293 }
294
295 fn validate_request(&self, request: &LLMRequest) -> Result<(), LLMError> {
296 if !self.supported_models().contains(&request.model) {
297 let formatted_error = error_display::format_llm_error(
298 "Gemini",
299 &format!("Unsupported model: {}", request.model),
300 );
301 return Err(LLMError::InvalidRequest(formatted_error));
302 }
303 Ok(())
304 }
305}
306
307impl GeminiProvider {
308 fn convert_to_gemini_request(
309 &self,
310 request: &LLMRequest,
311 ) -> Result<GenerateContentRequest, LLMError> {
312 if self.prompt_cache_enabled
313 && matches!(
314 self.prompt_cache_settings.mode,
315 GeminiPromptCacheMode::Explicit
316 )
317 {
318 }
322
323 let mut call_map: HashMap<String, String> = HashMap::new();
324 for message in &request.messages {
325 if message.role == MessageRole::Assistant
326 && let Some(tool_calls) = &message.tool_calls
327 {
328 for tool_call in tool_calls {
329 call_map.insert(tool_call.id.clone(), tool_call.function.name.clone());
330 }
331 }
332 }
333
334 let mut contents: Vec<Content> = Vec::new();
335 for message in &request.messages {
336 if message.role == MessageRole::System {
337 continue;
338 }
339
340 let mut parts: Vec<Part> = Vec::new();
341 if message.role != MessageRole::Tool && !message.content.is_empty() {
342 parts.push(Part::Text {
343 text: message.content.clone(),
344 });
345 }
346
347 if message.role == MessageRole::Assistant
348 && let Some(tool_calls) = &message.tool_calls
349 {
350 for tool_call in tool_calls {
351 let parsed_args = serde_json::from_str(&tool_call.function.arguments)
352 .unwrap_or_else(|_| json!({}));
353 parts.push(Part::FunctionCall {
354 function_call: GeminiFunctionCall {
355 name: tool_call.function.name.clone(),
356 args: parsed_args,
357 id: Some(tool_call.id.clone()),
358 },
359 });
360 }
361 }
362
363 if message.role == MessageRole::Tool {
364 if let Some(tool_call_id) = &message.tool_call_id {
365 let func_name = call_map
366 .get(tool_call_id)
367 .cloned()
368 .unwrap_or_else(|| tool_call_id.clone());
369 let response_text = serde_json::from_str::<Value>(&message.content)
370 .map(|value| {
371 serde_json::to_string_pretty(&value)
372 .unwrap_or_else(|_| message.content.clone())
373 })
374 .unwrap_or_else(|_| message.content.clone());
375
376 let response_payload = json!({
377 "name": func_name.clone(),
378 "content": [{
379 "text": response_text
380 }]
381 });
382
383 parts.push(Part::FunctionResponse {
384 function_response: FunctionResponse {
385 name: func_name,
386 response: response_payload,
387 },
388 });
389 } else if !message.content.is_empty() {
390 parts.push(Part::Text {
391 text: message.content.clone(),
392 });
393 }
394 }
395
396 if !parts.is_empty() {
397 contents.push(Content {
398 role: message.role.as_gemini_str().to_string(),
399 parts,
400 });
401 }
402 }
403
404 let tools: Option<Vec<Tool>> = request.tools.as_ref().map(|definitions| {
405 definitions
406 .iter()
407 .map(|tool| Tool {
408 function_declarations: vec![FunctionDeclaration {
409 name: tool.function.name.clone(),
410 description: tool.function.description.clone(),
411 parameters: tool.function.parameters.clone(),
412 }],
413 })
414 .collect()
415 });
416
417 let mut generation_config = Map::new();
418 if let Some(max_tokens) = request.max_tokens {
419 generation_config.insert("maxOutputTokens".to_string(), json!(max_tokens));
420 }
421 if let Some(temp) = request.temperature {
422 generation_config.insert("temperature".to_string(), json!(temp));
423 }
424 let has_tools = request
425 .tools
426 .as_ref()
427 .map(|defs| !defs.is_empty())
428 .unwrap_or(false);
429 let tool_config = if has_tools || request.tool_choice.is_some() {
430 Some(match request.tool_choice.as_ref() {
431 Some(ToolChoice::None) => ToolConfig {
432 function_calling_config: FunctionCallingConfig::none(),
433 },
434 Some(ToolChoice::Any) => ToolConfig {
435 function_calling_config: FunctionCallingConfig::any(),
436 },
437 Some(ToolChoice::Specific(spec)) => {
438 let mut config = FunctionCallingConfig::any();
439 if spec.tool_type == "function" {
440 config.allowed_function_names = Some(vec![spec.function.name.clone()]);
441 }
442 ToolConfig {
443 function_calling_config: config,
444 }
445 }
446 _ => ToolConfig::auto(),
447 })
448 } else {
449 None
450 };
451
452 Ok(GenerateContentRequest {
453 contents,
454 tools,
455 tool_config,
456 system_instruction: request
457 .system_prompt
458 .as_ref()
459 .map(|text| SystemInstruction::new(text.clone())),
460 generation_config: if generation_config.is_empty() {
461 None
462 } else {
463 Some(Value::Object(generation_config))
464 },
465 reasoning_config: None,
466 })
467 }
468
469 fn convert_from_gemini_response(
470 response: GenerateContentResponse,
471 ) -> Result<LLMResponse, LLMError> {
472 let mut candidates = response.candidates.into_iter();
473 let candidate = candidates.next().ok_or_else(|| {
474 let formatted_error =
475 error_display::format_llm_error("Gemini", "No candidate in response");
476 LLMError::Provider(formatted_error)
477 })?;
478
479 if candidate.content.parts.is_empty() {
480 return Ok(LLMResponse {
481 content: Some(String::new()),
482 tool_calls: None,
483 usage: None,
484 finish_reason: FinishReason::Stop,
485 reasoning: None,
486 });
487 }
488
489 let mut text_content = String::new();
490 let mut tool_calls = Vec::new();
491
492 for part in candidate.content.parts {
493 match part {
494 Part::Text { text } => {
495 text_content.push_str(&text);
496 }
497 Part::FunctionCall { function_call } => {
498 let call_id = function_call.id.clone().unwrap_or_else(|| {
499 format!(
500 "call_{}_{}",
501 std::time::SystemTime::now()
502 .duration_since(std::time::UNIX_EPOCH)
503 .unwrap_or_default()
504 .as_nanos(),
505 tool_calls.len()
506 )
507 });
508 tool_calls.push(ToolCall {
509 id: call_id,
510 call_type: "function".to_string(),
511 function: FunctionCall {
512 name: function_call.name,
513 arguments: serde_json::to_string(&function_call.args)
514 .unwrap_or_else(|_| "{}".to_string()),
515 },
516 });
517 }
518 Part::FunctionResponse { .. } => {
519 }
521 }
522 }
523
524 let finish_reason = match candidate.finish_reason.as_deref() {
525 Some("STOP") => FinishReason::Stop,
526 Some("MAX_TOKENS") => FinishReason::Length,
527 Some("SAFETY") => FinishReason::ContentFilter,
528 Some("FUNCTION_CALL") => FinishReason::ToolCalls,
529 Some(other) => FinishReason::Error(other.to_string()),
530 None => FinishReason::Stop,
531 };
532
533 Ok(LLMResponse {
534 content: if text_content.is_empty() {
535 None
536 } else {
537 Some(text_content)
538 },
539 tool_calls: if tool_calls.is_empty() {
540 None
541 } else {
542 Some(tool_calls)
543 },
544 usage: None,
545 finish_reason,
546 reasoning: None,
547 })
548 }
549
550 fn convert_from_streaming_response(
551 response: StreamingResponse,
552 ) -> Result<LLMResponse, LLMError> {
553 let converted_candidates: Vec<Candidate> = response
554 .candidates
555 .into_iter()
556 .map(|candidate| Candidate {
557 content: candidate.content,
558 finish_reason: candidate.finish_reason,
559 })
560 .collect();
561
562 let converted = GenerateContentResponse {
563 candidates: converted_candidates,
564 prompt_feedback: None,
565 usage_metadata: response.usage_metadata,
566 };
567
568 Self::convert_from_gemini_response(converted)
569 }
570
571 fn map_streaming_error(error: StreamingError) -> LLMError {
572 match error {
573 StreamingError::NetworkError { message, .. } => {
574 let formatted = error_display::format_llm_error(
575 "Gemini",
576 &format!("Network error: {}", message),
577 );
578 LLMError::Network(formatted)
579 }
580 StreamingError::ApiError {
581 status_code,
582 message,
583 ..
584 } => {
585 if status_code == 401 || status_code == 403 {
586 let formatted = error_display::format_llm_error(
587 "Gemini",
588 &format!("HTTP {}: {}", status_code, message),
589 );
590 LLMError::Authentication(formatted)
591 } else if status_code == 429 {
592 LLMError::RateLimit
593 } else {
594 let formatted = error_display::format_llm_error(
595 "Gemini",
596 &format!("API error ({}): {}", status_code, message),
597 );
598 LLMError::Provider(formatted)
599 }
600 }
601 StreamingError::ParseError { message, .. } => {
602 let formatted =
603 error_display::format_llm_error("Gemini", &format!("Parse error: {}", message));
604 LLMError::Provider(formatted)
605 }
606 StreamingError::TimeoutError {
607 operation,
608 duration,
609 } => {
610 let formatted = error_display::format_llm_error(
611 "Gemini",
612 &format!(
613 "Streaming timeout during {} after {:?}",
614 operation, duration
615 ),
616 );
617 LLMError::Network(formatted)
618 }
619 StreamingError::ContentError { message } => {
620 let formatted = error_display::format_llm_error(
621 "Gemini",
622 &format!("Content error: {}", message),
623 );
624 LLMError::Provider(formatted)
625 }
626 StreamingError::StreamingError { message, .. } => {
627 let formatted = error_display::format_llm_error(
628 "Gemini",
629 &format!("Streaming error: {}", message),
630 );
631 LLMError::Provider(formatted)
632 }
633 }
634 }
635}
636
637#[async_trait]
638impl LLMClient for GeminiProvider {
639 async fn generate(&mut self, prompt: &str) -> Result<llm_types::LLMResponse, LLMError> {
640 let request = if prompt.starts_with('{') && prompt.contains("\"contents\"") {
642 match serde_json::from_str::<crate::gemini::GenerateContentRequest>(prompt) {
644 Ok(gemini_request) => {
645 let mut messages = Vec::new();
647 let mut system_prompt = None;
648
649 for content in &gemini_request.contents {
651 let role = match content.role.as_str() {
652 crate::config::constants::message_roles::USER => MessageRole::User,
653 "model" => MessageRole::Assistant,
654 crate::config::constants::message_roles::SYSTEM => {
655 let text = content
657 .parts
658 .iter()
659 .filter_map(|part| part.as_text())
660 .collect::<Vec<_>>()
661 .join("");
662 system_prompt = Some(text);
663 continue;
664 }
665 _ => MessageRole::User, };
667
668 let content_text = content
669 .parts
670 .iter()
671 .filter_map(|part| part.as_text())
672 .collect::<Vec<_>>()
673 .join("");
674
675 messages.push(Message {
676 role,
677 content: content_text,
678 tool_calls: None,
679 tool_call_id: None,
680 });
681 }
682
683 let tools = gemini_request.tools.as_ref().map(|gemini_tools| {
685 gemini_tools
686 .iter()
687 .flat_map(|tool| &tool.function_declarations)
688 .map(|decl| crate::llm::provider::ToolDefinition {
689 tool_type: "function".to_string(),
690 function: crate::llm::provider::FunctionDefinition {
691 name: decl.name.clone(),
692 description: decl.description.clone(),
693 parameters: decl.parameters.clone(),
694 },
695 })
696 .collect::<Vec<_>>()
697 });
698
699 let llm_request = LLMRequest {
700 messages,
701 system_prompt,
702 tools,
703 model: self.model.clone(),
704 max_tokens: gemini_request
705 .generation_config
706 .as_ref()
707 .and_then(|config| config.get("maxOutputTokens"))
708 .and_then(|v| v.as_u64())
709 .map(|v| v as u32),
710 temperature: gemini_request
711 .generation_config
712 .as_ref()
713 .and_then(|config| config.get("temperature"))
714 .and_then(|v| v.as_f64())
715 .map(|v| v as f32),
716 stream: false,
717 tool_choice: None,
718 parallel_tool_calls: None,
719 parallel_tool_config: None,
720 reasoning_effort: None,
721 };
722
723 let response = LLMProvider::generate(self, llm_request).await?;
725
726 let content = if let Some(tool_calls) = &response.tool_calls {
728 if !tool_calls.is_empty() {
729 let tool_call_json = json!({
731 "tool_calls": tool_calls.iter().map(|tc| {
732 json!({
733 "function": {
734 "name": tc.function.name,
735 "arguments": tc.function.arguments
736 }
737 })
738 }).collect::<Vec<_>>()
739 });
740 tool_call_json.to_string()
741 } else {
742 response.content.unwrap_or("".to_string())
743 }
744 } else {
745 response.content.unwrap_or("".to_string())
746 };
747
748 return Ok(llm_types::LLMResponse {
749 content,
750 model: self.model.clone(),
751 usage: response.usage.map(|u| llm_types::Usage {
752 prompt_tokens: u.prompt_tokens as usize,
753 completion_tokens: u.completion_tokens as usize,
754 total_tokens: u.total_tokens as usize,
755 cached_prompt_tokens: u.cached_prompt_tokens.map(|v| v as usize),
756 cache_creation_tokens: u.cache_creation_tokens.map(|v| v as usize),
757 cache_read_tokens: u.cache_read_tokens.map(|v| v as usize),
758 }),
759 reasoning: response.reasoning,
760 });
761 }
762 Err(_) => {
763 LLMRequest {
765 messages: vec![Message {
766 role: MessageRole::User,
767 content: prompt.to_string(),
768 tool_calls: None,
769 tool_call_id: None,
770 }],
771 system_prompt: None,
772 tools: None,
773 model: self.model.clone(),
774 max_tokens: None,
775 temperature: None,
776 stream: false,
777 tool_choice: None,
778 parallel_tool_calls: None,
779 parallel_tool_config: None,
780 reasoning_effort: None,
781 }
782 }
783 }
784 } else {
785 LLMRequest {
787 messages: vec![Message {
788 role: MessageRole::User,
789 content: prompt.to_string(),
790 tool_calls: None,
791 tool_call_id: None,
792 }],
793 system_prompt: None,
794 tools: None,
795 model: self.model.clone(),
796 max_tokens: None,
797 temperature: None,
798 stream: false,
799 tool_choice: None,
800 parallel_tool_calls: None,
801 parallel_tool_config: None,
802 reasoning_effort: None,
803 }
804 };
805
806 let response = LLMProvider::generate(self, request).await?;
807
808 Ok(llm_types::LLMResponse {
809 content: response.content.unwrap_or("".to_string()),
810 model: self.model.clone(),
811 usage: response.usage.map(|u| llm_types::Usage {
812 prompt_tokens: u.prompt_tokens as usize,
813 completion_tokens: u.completion_tokens as usize,
814 total_tokens: u.total_tokens as usize,
815 cached_prompt_tokens: u.cached_prompt_tokens.map(|v| v as usize),
816 cache_creation_tokens: u.cache_creation_tokens.map(|v| v as usize),
817 cache_read_tokens: u.cache_read_tokens.map(|v| v as usize),
818 }),
819 reasoning: response.reasoning,
820 })
821 }
822
823 fn backend_kind(&self) -> llm_types::BackendKind {
824 llm_types::BackendKind::Gemini
825 }
826
827 fn model_id(&self) -> &str {
828 &self.model
829 }
830}
831
832#[cfg(test)]
833mod tests {
834 use super::*;
835 use crate::config::constants::models;
836 use crate::llm::provider::{SpecificFunctionChoice, SpecificToolChoice, ToolDefinition};
837
838 #[test]
839 fn convert_to_gemini_request_maps_history_and_system_prompt() {
840 let provider = GeminiProvider::new("test-key".to_string());
841 let mut assistant_message = Message::assistant("Sure thing".to_string());
842 assistant_message.tool_calls = Some(vec![ToolCall::function(
843 "call_1".to_string(),
844 "list_files".to_string(),
845 json!({ "path": "." }).to_string(),
846 )]);
847
848 let tool_response =
849 Message::tool_response("call_1".to_string(), json!({ "result": "ok" }).to_string());
850
851 let tool_def = ToolDefinition::function(
852 "list_files".to_string(),
853 "List files".to_string(),
854 json!({
855 "type": "object",
856 "properties": {
857 "path": { "type": "string" }
858 }
859 }),
860 );
861
862 let request = LLMRequest {
863 messages: vec![
864 Message::user("hello".to_string()),
865 assistant_message,
866 tool_response,
867 ],
868 system_prompt: Some("System prompt".to_string()),
869 tools: Some(vec![tool_def]),
870 model: models::google::GEMINI_2_5_FLASH_PREVIEW.to_string(),
871 max_tokens: Some(256),
872 temperature: Some(0.4),
873 stream: false,
874 tool_choice: Some(ToolChoice::Specific(SpecificToolChoice {
875 tool_type: "function".to_string(),
876 function: SpecificFunctionChoice {
877 name: "list_files".to_string(),
878 },
879 })),
880 parallel_tool_calls: None,
881 parallel_tool_config: None,
882 reasoning_effort: None,
883 };
884
885 let gemini_request = provider
886 .convert_to_gemini_request(&request)
887 .expect("conversion should succeed");
888
889 let system_instruction = gemini_request
890 .system_instruction
891 .expect("system instruction should be present");
892 assert!(matches!(
893 system_instruction.parts.as_slice(),
894 [Part::Text { text }] if text == "System prompt"
895 ));
896
897 assert_eq!(gemini_request.contents.len(), 3);
898 assert_eq!(gemini_request.contents[0].role, "user");
899 assert!(
900 gemini_request.contents[1]
901 .parts
902 .iter()
903 .any(|part| matches!(part, Part::FunctionCall { .. }))
904 );
905 let tool_part = gemini_request.contents[2]
906 .parts
907 .iter()
908 .find_map(|part| match part {
909 Part::FunctionResponse { function_response } => Some(function_response),
910 _ => None,
911 })
912 .expect("tool response part should exist");
913 assert_eq!(tool_part.name, "list_files");
914 }
915
916 #[test]
917 fn convert_from_gemini_response_extracts_tool_calls() {
918 let response = GenerateContentResponse {
919 candidates: vec![crate::gemini::Candidate {
920 content: Content {
921 role: "model".to_string(),
922 parts: vec![
923 Part::Text {
924 text: "Here you go".to_string(),
925 },
926 Part::FunctionCall {
927 function_call: GeminiFunctionCall {
928 name: "list_files".to_string(),
929 args: json!({ "path": "." }),
930 id: Some("call_1".to_string()),
931 },
932 },
933 ],
934 },
935 finish_reason: Some("FUNCTION_CALL".to_string()),
936 }],
937 prompt_feedback: None,
938 usage_metadata: None,
939 };
940
941 let llm_response = GeminiProvider::convert_from_gemini_response(response)
942 .expect("conversion should succeed");
943
944 assert_eq!(llm_response.content.as_deref(), Some("Here you go"));
945 let calls = llm_response
946 .tool_calls
947 .expect("tool call should be present");
948 assert_eq!(calls.len(), 1);
949 assert_eq!(calls[0].function.name, "list_files");
950 assert!(calls[0].function.arguments.contains("path"));
951 assert_eq!(llm_response.finish_reason, FinishReason::ToolCalls);
952 }
953}