1use async_trait::async_trait;
2use reqwest::{Client as HttpClient, StatusCode};
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use tokio_util::sync::CancellationToken;
6use tracing::{debug, error, info, warn};
7
8use crate::api::error::ApiError;
9use crate::api::provider::{CompletionResponse, Provider};
10use crate::app::conversation::{
11 AssistantContent, Message as AppMessage, ThoughtContent, ToolResult, UserContent,
12};
13use crate::config::model::{ModelId, ModelParameters};
14use steer_tools::ToolSchema;
15
/// Base URL of the Gemini `v1beta` REST API; `models/{model}:generateContent` is appended to it.
const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
17
18#[derive(Debug, Deserialize, Serialize, Clone)] struct GeminiBlob {
20 #[serde(rename = "mimeType")]
21 mime_type: String,
22 data: String, }
24
25#[derive(Debug, Deserialize, Serialize, Clone)] struct GeminiFileData {
27 #[serde(rename = "mimeType")]
28 mime_type: String,
29 #[serde(rename = "fileUri")]
30 file_uri: String,
31}
32
33#[derive(Debug, Deserialize, Serialize, Clone)] struct GeminiCodeExecutionResult {
35 outcome: String, }
38
39pub struct GeminiClient {
40 api_key: String,
41 client: HttpClient,
42}
43
44impl GeminiClient {
45 pub fn new(api_key: impl Into<String>) -> Self {
46 Self {
47 api_key: api_key.into(),
48 client: HttpClient::new(),
49 }
50 }
51}
52
53#[derive(Debug, Serialize)]
54struct GeminiRequest {
55 contents: Vec<GeminiContent>,
56 #[serde(skip_serializing_if = "Option::is_none")]
57 #[serde(rename = "systemInstruction")]
58 system_instruction: Option<GeminiSystemInstruction>,
59 #[serde(skip_serializing_if = "Option::is_none")]
60 tools: Option<Vec<GeminiTool>>,
61 #[serde(skip_serializing_if = "Option::is_none")]
62 #[serde(rename = "generationConfig")]
63 generation_config: Option<GeminiGenerationConfig>,
64}
65
66#[derive(Debug, Serialize, Default, Clone)]
67struct GeminiGenerationConfig {
68 #[serde(skip_serializing_if = "Option::is_none")]
69 #[serde(rename = "stopSequences")]
70 stop_sequences: Option<Vec<String>>,
71 #[serde(skip_serializing_if = "Option::is_none")]
72 #[serde(rename = "responseMimeType")]
73 response_mime_type: Option<GeminiMimeType>,
74 #[serde(skip_serializing_if = "Option::is_none")]
75 #[serde(rename = "candidateCount")]
76 candidate_count: Option<i32>,
77 #[serde(skip_serializing_if = "Option::is_none")]
78 #[serde(rename = "maxOutputTokens")]
79 max_output_tokens: Option<i32>,
80 #[serde(skip_serializing_if = "Option::is_none")]
81 temperature: Option<f32>,
82 #[serde(skip_serializing_if = "Option::is_none")]
83 #[serde(rename = "topP")]
84 top_p: Option<f32>,
85 #[serde(skip_serializing_if = "Option::is_none")]
86 #[serde(rename = "topK")]
87 top_k: Option<i32>,
88 #[serde(skip_serializing_if = "Option::is_none")]
89 #[serde(rename = "thinkingConfig")]
90 thinking_config: Option<GeminiThinkingConfig>,
91}
92
93#[derive(Debug, Serialize, Default, Clone)]
94struct GeminiThinkingConfig {
95 #[serde(skip_serializing_if = "Option::is_none")]
96 #[serde(rename = "includeThoughts")]
97 include_thoughts: Option<bool>,
98 #[serde(skip_serializing_if = "Option::is_none")]
99 #[serde(rename = "thinkingBudget")]
100 thinking_budget: Option<i32>,
101}
102
103#[allow(dead_code)]
104#[derive(Debug, Serialize, Clone)]
105#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
106enum GeminiMimeType {
107 MimeTypeUnspecified,
108 TextPlain,
109 ApplicationJson,
110}
111
112#[derive(Debug, Serialize)]
113struct GeminiSystemInstruction {
114 parts: Vec<GeminiRequestPart>,
115}
116
117#[derive(Debug, Serialize)]
118struct GeminiContent {
119 role: String,
120 parts: Vec<GeminiRequestPart>,
121}
122
123#[derive(Debug, Serialize)]
125#[serde(untagged)]
126enum GeminiRequestPart {
127 Text {
128 text: String,
129 },
130 #[serde(rename = "functionCall")]
131 FunctionCall {
132 #[serde(rename = "functionCall")]
133 function_call: GeminiFunctionCall, },
135 #[serde(rename = "functionResponse")]
136 FunctionResponse {
137 #[serde(rename = "functionResponse")]
138 function_response: GeminiFunctionResponse, },
140}
141
142#[derive(Debug, Deserialize)]
144#[serde(untagged)]
145enum GeminiResponsePartData {
146 Text {
147 text: String,
148 },
149 #[serde(rename = "inlineData")]
150 InlineData {
151 #[serde(rename = "inlineData")]
152 inline_data: GeminiBlob,
153 },
154 #[serde(rename = "functionCall")]
155 FunctionCall {
156 #[serde(rename = "functionCall")]
157 function_call: GeminiFunctionCall,
158 },
159 #[serde(rename = "fileData")]
160 FileData {
161 #[serde(rename = "fileData")]
162 file_data: GeminiFileData,
163 },
164 #[serde(rename = "executableCode")]
165 ExecutableCode {
166 #[serde(rename = "executableCode")]
167 executable_code: GeminiExecutableCode,
168 },
169 }
171
172#[derive(Debug, Deserialize)]
174struct GeminiResponsePart {
175 #[serde(default)] thought: bool,
177
178 #[serde(flatten)] data: GeminiResponsePartData,
180}
181
182#[derive(Debug, Serialize, Deserialize)]
183struct GeminiFunctionCall {
184 name: String,
185 args: Value,
186}
187
188#[derive(Debug, Serialize, PartialEq)]
189struct GeminiTool {
190 #[serde(rename = "functionDeclarations")]
191 function_declarations: Vec<GeminiFunctionDeclaration>,
192}
193
194#[derive(Debug, Serialize, PartialEq)]
195struct GeminiFunctionDeclaration {
196 name: String,
197 description: String,
198 parameters: GeminiParameterSchema,
199}
200
201#[derive(Debug, Serialize, PartialEq)]
202struct GeminiParameterSchema {
203 #[serde(rename = "type")]
204 schema_type: String, properties: serde_json::Map<String, Value>,
206 required: Vec<String>,
207}
208
209#[derive(Debug, Deserialize)]
210struct GeminiResponse {
211 #[serde(rename = "candidates")]
212 #[serde(skip_serializing_if = "Option::is_none")]
213 candidates: Option<Vec<GeminiCandidate>>,
214 #[serde(rename = "promptFeedback")]
215 #[serde(skip_serializing_if = "Option::is_none")]
216 prompt_feedback: Option<GeminiPromptFeedback>,
217 #[serde(rename = "usageMetadata")]
218 #[serde(skip_serializing_if = "Option::is_none")]
219 usage_metadata: Option<GeminiUsageMetadata>,
220}
221
222#[derive(Debug, Deserialize)]
223struct GeminiCandidate {
224 content: GeminiContentResponse,
225 #[serde(rename = "finishReason")]
226 #[serde(skip_serializing_if = "Option::is_none")]
227 finish_reason: Option<GeminiFinishReason>,
228 #[serde(rename = "safetyRatings")]
229 #[serde(skip_serializing_if = "Option::is_none")]
230 safety_ratings: Option<Vec<GeminiSafetyRating>>,
231 #[serde(rename = "citationMetadata")]
232 #[serde(skip_serializing_if = "Option::is_none")]
233 citation_metadata: Option<GeminiCitationMetadata>,
234}
235
236#[derive(Debug, Deserialize)]
237#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
238enum GeminiFinishReason {
239 FinishReasonUnspecified,
240 Stop,
241 MaxTokens,
242 Safety,
243 Recitation,
244 Other,
245 #[serde(rename = "TOOL_CODE_ERROR")]
246 ToolCodeError,
247 #[serde(rename = "TOOL_EXECUTION_HALT")]
248 ToolExecutionHalt,
249 MalformedFunctionCall,
250}
251
252#[derive(Debug, Deserialize)]
253struct GeminiPromptFeedback {
254 #[serde(rename = "blockReason")]
255 #[serde(skip_serializing_if = "Option::is_none")]
256 block_reason: Option<GeminiBlockReason>,
257 #[serde(rename = "safetyRatings")]
258 #[serde(skip_serializing_if = "Option::is_none")]
259 safety_ratings: Option<Vec<GeminiSafetyRating>>,
260}
261
262#[derive(Debug, Deserialize)]
263#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
264enum GeminiBlockReason {
265 BlockReasonUnspecified,
266 Safety,
267 Other,
268}
269
270#[derive(Debug, Deserialize)]
271#[allow(dead_code)]
272struct GeminiSafetyRating {
273 category: GeminiHarmCategory,
274 probability: GeminiHarmProbability,
275 #[serde(default)] blocked: bool,
277}
278
279#[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
281#[allow(clippy::enum_variant_names)]
282enum GeminiHarmCategory {
283 HarmCategoryUnspecified,
284 HarmCategoryDerogatory,
285 HarmCategoryToxicity,
286 HarmCategoryViolence,
287 HarmCategorySexual,
288 HarmCategoryMedical,
289 HarmCategoryDangerous,
290 HarmCategoryHarassment,
291 HarmCategoryHateSpeech,
292 HarmCategorySexuallyExplicit,
293 HarmCategoryDangerousContent,
294 HarmCategoryCivicIntegrity,
295}
296
297#[derive(Debug, Deserialize)]
298#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
299enum GeminiHarmProbability {
300 HarmProbabilityUnspecified,
301 Negligible,
302 Low,
303 Medium,
304 High,
305}
306
307#[allow(dead_code)]
308#[derive(Debug, Deserialize)]
309struct GeminiCitationMetadata {
310 #[serde(rename = "citationSources")]
311 #[serde(skip_serializing_if = "Option::is_none")]
312 citation_sources: Option<Vec<GeminiCitationSource>>,
313}
314
315#[allow(dead_code)]
316#[derive(Debug, Deserialize)]
317struct GeminiCitationSource {
318 #[serde(rename = "startIndex")]
319 #[serde(skip_serializing_if = "Option::is_none")]
320 start_index: Option<i32>,
321 #[serde(rename = "endIndex")]
322 #[serde(skip_serializing_if = "Option::is_none")]
323 end_index: Option<i32>,
324 #[serde(skip_serializing_if = "Option::is_none")]
325 uri: Option<String>,
326 #[serde(skip_serializing_if = "Option::is_none")]
327 license: Option<String>,
328}
329
330#[derive(Debug, Deserialize)]
331struct GeminiUsageMetadata {
332 #[serde(rename = "promptTokenCount")]
333 #[serde(skip_serializing_if = "Option::is_none")]
334 prompt_token_count: Option<i32>,
335 #[serde(rename = "candidatesTokenCount")]
336 #[serde(skip_serializing_if = "Option::is_none")]
337 candidates_token_count: Option<i32>,
338 #[serde(rename = "totalTokenCount")]
339 #[serde(skip_serializing_if = "Option::is_none")]
340 total_token_count: Option<i32>,
341}
342
343#[derive(Debug, Serialize, Deserialize)]
344struct GeminiFunctionResponse {
345 name: String,
346 response: GeminiResponseContent,
347}
348
349#[derive(Debug, Serialize, Deserialize)]
350struct GeminiResponseContent {
351 content: Value,
352}
353
354#[derive(Debug, Serialize, Deserialize)]
355struct GeminiExecutableCode {
356 language: String, code: String,
358}
359
360#[derive(Debug, Deserialize)]
361#[allow(dead_code)]
362struct GeminiContentResponse {
363 role: String,
364 parts: Vec<GeminiResponsePart>,
365}
366
367fn convert_messages(messages: Vec<AppMessage>) -> Vec<GeminiContent> {
368 messages
369 .into_iter()
370 .filter_map(|msg| match &msg.data {
371 crate::app::conversation::MessageData::User { content, .. } => {
372 let parts: Vec<GeminiRequestPart> = content
373 .iter()
374 .filter_map(|user_content| match user_content {
375 UserContent::Text { text } => {
376 Some(GeminiRequestPart::Text { text: text.clone() })
377 }
378 UserContent::CommandExecution {
379 command,
380 stdout,
381 stderr,
382 exit_code,
383 } => Some(GeminiRequestPart::Text {
384 text: UserContent::format_command_execution_as_xml(
385 command, stdout, stderr, *exit_code,
386 ),
387 }),
388 UserContent::AppCommand { .. } => {
389 None
391 }
392 })
393 .collect();
394
395 if parts.is_empty() {
397 None
398 } else {
399 Some(GeminiContent {
400 role: "user".to_string(),
401 parts,
402 })
403 }
404 }
405 crate::app::conversation::MessageData::Assistant { content, .. } => {
406 let parts: Vec<GeminiRequestPart> = content
407 .iter()
408 .filter_map(|assistant_content| match assistant_content {
409 AssistantContent::Text { text } => {
410 Some(GeminiRequestPart::Text { text: text.clone() })
411 }
412 AssistantContent::ToolCall { tool_call } => {
413 Some(GeminiRequestPart::FunctionCall {
414 function_call: GeminiFunctionCall {
415 name: tool_call.name.clone(),
416 args: tool_call.parameters.clone(),
417 },
418 })
419 }
420 AssistantContent::Thought { .. } => {
421 None
423 }
424 })
425 .collect();
426
427 Some(GeminiContent {
429 role: "model".to_string(),
430 parts,
431 })
432 }
433 crate::app::conversation::MessageData::Tool {
434 tool_use_id,
435 result,
436 ..
437 } => {
438 let result_value = match result {
440 ToolResult::Error(e) => Value::String(format!("Error: {e}")),
441 _ => {
442 serde_json::to_value(result)
444 .unwrap_or_else(|_| Value::String(result.llm_format()))
445 }
446 };
447
448 let parts = vec![GeminiRequestPart::FunctionResponse {
449 function_response: GeminiFunctionResponse {
450 name: tool_use_id.clone(), response: GeminiResponseContent {
452 content: result_value,
453 },
454 },
455 }];
456
457 Some(GeminiContent {
458 role: "function".to_string(),
459 parts,
460 })
461 }
462 })
463 .collect()
464}
465
466fn simplify_property_schema(key: &str, tool_name: &str, property_value: &Value) -> Value {
467 if let Some(prop_map_orig) = property_value.as_object() {
468 let mut simplified_prop = prop_map_orig.clone();
469
470 if simplified_prop.remove("additionalProperties").is_some() {
472 debug!(target: "gemini::simplify_property_schema", "Removed 'additionalProperties' from property '{}' in tool '{}'", key, tool_name);
473 }
474
475 if let Some(type_val) = simplified_prop.get_mut("type") {
477 if let Some(type_array) = type_val.as_array() {
478 if let Some(primary_type) = type_array
479 .iter()
480 .find_map(|v| if !v.is_null() { v.as_str() } else { None })
481 {
482 *type_val = serde_json::Value::String(primary_type.to_string());
483 } else {
484 warn!(target: "gemini::simplify_property_schema", "Could not determine primary type for property '{}' in tool '{}', defaulting to string.", key, tool_name);
485 *type_val = serde_json::Value::String("string".to_string());
486 }
487 } else if !type_val.is_string() {
488 warn!(target: "gemini::simplify_property_schema", "Unexpected 'type' format for property '{}' in tool '{}': {:?}. Defaulting to string.", key, tool_name, type_val);
489 *type_val = serde_json::Value::String("string".to_string());
490 }
491 }
493
494 if simplified_prop.get("type") == Some(&serde_json::Value::String("integer".to_string())) {
496 if let Some(format_val) = simplified_prop.get_mut("format") {
497 if format_val.as_str() == Some("uint64") {
498 *format_val = serde_json::Value::String("int64".to_string());
499 }
502 }
503 }
504
505 if simplified_prop.get("type") == Some(&serde_json::Value::String("string".to_string())) {
507 let should_remove_format = simplified_prop
508 .get("format")
509 .and_then(|f| f.as_str())
510 .map(|format_str| format_str != "enum" && format_str != "date-time")
511 .unwrap_or(false);
512
513 if should_remove_format {
514 if let Some(format_val) = simplified_prop.remove("format") {
515 if let Some(format_str) = format_val.as_str() {
516 debug!(target: "gemini::simplify_property_schema", "Removed unsupported format '{}' from string property '{}' in tool '{}'", format_str, key, tool_name);
517 }
518 }
519 }
520
521 if simplified_prop.remove("minLength").is_some() {
523 debug!(target: "gemini::simplify_property_schema", "Removed 'minLength' from string property '{}' in tool '{}'", key, tool_name);
524 }
525 if simplified_prop.remove("maxLength").is_some() {
526 debug!(target: "gemini::simplify_property_schema", "Removed 'maxLength' from string property '{}' in tool '{}'", key, tool_name);
527 }
528 if simplified_prop.remove("pattern").is_some() {
529 debug!(target: "gemini::simplify_property_schema", "Removed 'pattern' from string property '{}' in tool '{}'", key, tool_name);
530 }
531 }
532
533 if simplified_prop.get("type") == Some(&serde_json::Value::String("array".to_string())) {
535 if let Some(items_val) = simplified_prop.get_mut("items") {
536 *items_val =
537 simplify_property_schema(&format!("{key}.items"), tool_name, items_val);
538 }
539 }
540
541 if simplified_prop.get("type") == Some(&serde_json::Value::String("object".to_string())) {
543 if let Some(Value::Object(props)) = simplified_prop.get_mut("properties") {
544 let simplified_nested_props: serde_json::Map<String, Value> = props
545 .iter()
546 .map(|(nested_key, nested_value)| {
547 (
548 nested_key.clone(),
549 simplify_property_schema(
550 &format!("{key}.{nested_key}"),
551 tool_name,
552 nested_value,
553 ),
554 )
555 })
556 .collect();
557 *props = simplified_nested_props;
558 }
559 }
560
561 serde_json::Value::Object(simplified_prop)
562 } else {
563 warn!(target: "gemini::simplify_property_schema", "Property value for '{}' in tool '{}' is not an object: {:?}. Using original value.", key, tool_name, property_value);
564 property_value.clone() }
566}
567
568fn convert_tools(tools: Vec<ToolSchema>) -> Vec<GeminiTool> {
569 let function_declarations = tools
570 .into_iter()
571 .map(|tool| {
572 let simplified_properties = tool
574 .input_schema
575 .properties
576 .iter()
577 .map(|(key, value)| {
578 (
579 key.clone(),
580 simplify_property_schema(key, &tool.name, value),
581 )
582 })
583 .collect();
584
585 let parameters = GeminiParameterSchema {
587 schema_type: tool.input_schema.schema_type, properties: simplified_properties, required: tool.input_schema.required, };
591
592 GeminiFunctionDeclaration {
593 name: tool.name,
594 description: tool.description,
595 parameters,
596 }
597 })
598 .collect();
599
600 vec![GeminiTool {
601 function_declarations,
602 }]
603}
604
605fn convert_response(response: GeminiResponse) -> Result<CompletionResponse, ApiError> {
606 if let Some(feedback) = &response.prompt_feedback {
608 if let Some(reason) = &feedback.block_reason {
609 let details = format!(
610 "Prompt blocked due to {:?}. Safety ratings: {:?}",
611 reason, feedback.safety_ratings
612 );
613 warn!(target: "gemini::convert_response", "{}", details);
614 return Err(ApiError::RequestBlocked {
616 provider: "google".to_string(), details,
618 });
619 }
620 }
621
622 let candidates = match response.candidates {
624 Some(cands) => {
625 if cands.is_empty() {
626 warn!(target: "gemini::convert_response", "No candidates received, and prompt was not blocked.");
629 return Err(ApiError::NoChoices {
631 provider: "google".to_string(),
632 });
633 }
634 cands }
636 None => {
637 warn!(target: "gemini::convert_response", "No candidates field in Gemini response.");
638 return Err(ApiError::NoChoices {
640 provider: "google".to_string(),
641 });
642 }
643 };
644
645 let candidate = &candidates[0];
648
649 if let Some(reason) = &candidate.finish_reason {
651 match reason {
652 GeminiFinishReason::Stop => { }
653 GeminiFinishReason::MaxTokens => {
654 warn!(target: "gemini::convert_response", "Response stopped due to MaxTokens limit.");
655 }
656 GeminiFinishReason::Safety => {
657 warn!(target: "gemini::convert_response", "Response stopped due to safety settings. Ratings: {:?}", candidate.safety_ratings);
658 }
660 GeminiFinishReason::Recitation => {
661 warn!(target: "gemini::convert_response", "Response stopped due to potential recitation. Citations: {:?}", candidate.citation_metadata);
662 }
663 GeminiFinishReason::MalformedFunctionCall => {
664 warn!(target: "gemini::convert_response", "Response stopped due to malformed function call.");
665 }
666 _ => {
667 info!(target: "gemini::convert_response", "Response finished with reason: {:?}", reason);
668 }
669 }
670 }
671
672 if let Some(usage) = &response.usage_metadata {
674 debug!(target: "gemini::convert_response", "Usage - Prompt Tokens: {:?}, Candidates Tokens: {:?}, Total Tokens: {:?}",
675 usage.prompt_token_count, usage.candidates_token_count, usage.total_token_count);
676 }
677
678 let content: Vec<AssistantContent> = candidate
679 .content .parts .iter()
682 .filter_map(|part| { if part.thought {
685 debug!(target: "gemini::convert_response", "Received thought part: {:?}", part);
686 match &part.data {
688 GeminiResponsePartData::Text { text } => {
689 Some(AssistantContent::Thought {
690 thought: ThoughtContent::Simple {
691 text: text.clone(),
692 },
693 })
694 }
695 _ => {
696 warn!(target: "gemini::convert_response", "Thought part contains non-text data: {:?}", part.data);
697 None
698 }
699 }
700 } else {
701 match &part.data {
703 GeminiResponsePartData::Text { text } => Some(AssistantContent::Text {
704 text: text.clone(),
705 }),
706 GeminiResponsePartData::InlineData { inline_data } => {
707 warn!(target: "gemini::convert_response", "Received InlineData part (MIME type: {}). Converting to placeholder text.", inline_data.mime_type);
708 Some(AssistantContent::Text { text: format!("[Inline Data: {}]", inline_data.mime_type) })
709 }
710 GeminiResponsePartData::FunctionCall { function_call } => {
711 Some(AssistantContent::ToolCall {
712 tool_call: steer_tools::ToolCall {
713 id: uuid::Uuid::new_v4().to_string(), name: function_call.name.clone(),
715 parameters: function_call.args.clone(),
716 },
717 })
718 }
719 GeminiResponsePartData::FileData { file_data } => {
720 warn!(target: "gemini::convert_response", "Received FileData part (URI: {}). Converting to placeholder text.", file_data.file_uri);
721 Some(AssistantContent::Text { text: format!("[File Data: {}]", file_data.file_uri) })
722 }
723 GeminiResponsePartData::ExecutableCode { executable_code } => {
724 info!(target: "gemini::convert_response", "Received ExecutableCode part ({}). Converting to text.",
725 executable_code.language);
726 Some(AssistantContent::Text {
727 text: format!(
728 "```{}
729{}
730```",
731 executable_code.language.to_lowercase(),
732 executable_code.code
733 ),
734 })
735 }
736 }
737 }
738 })
739 .collect();
740
741 Ok(CompletionResponse { content })
742}
743
744#[async_trait]
745impl Provider for GeminiClient {
746 fn name(&self) -> &'static str {
747 "google"
748 }
749
750 async fn complete(
751 &self,
752 model_id: &ModelId,
753 messages: Vec<AppMessage>,
754 system: Option<String>,
755 tools: Option<Vec<ToolSchema>>,
756 _call_options: Option<ModelParameters>,
757 token: CancellationToken,
758 ) -> Result<CompletionResponse, ApiError> {
759 let model_name = &model_id.1; let url = format!(
761 "{}/models/{}:generateContent?key={}",
762 GEMINI_API_BASE, model_name, self.api_key
763 );
764
765 let gemini_contents = convert_messages(messages);
766
767 let system_instruction = system.map(|instructions| GeminiSystemInstruction {
768 parts: vec![GeminiRequestPart::Text { text: instructions }],
769 });
770
771 let gemini_tools = tools.map(convert_tools);
772
773 let (temperature, top_p, max_output_tokens) = {
775 let opts = _call_options.as_ref();
776 (
777 opts.and_then(|o| o.temperature).or(Some(1.0)),
778 opts.and_then(|o| o.top_p).or(Some(0.95)),
779 opts.and_then(|o| o.max_tokens)
780 .map(|v| v as i32)
781 .or(Some(65536)),
782 )
783 };
784 let thinking_config = _call_options
785 .as_ref()
786 .and_then(|o| o.thinking_config)
787 .and_then(|tc| {
788 if !tc.enabled {
789 None
790 } else {
791 Some(GeminiThinkingConfig {
792 include_thoughts: tc.include_thoughts,
793 thinking_budget: tc.budget_tokens.map(|v| v as i32),
794 })
795 }
796 });
797
798 let request = GeminiRequest {
799 contents: gemini_contents,
800 system_instruction,
801 tools: gemini_tools,
802 generation_config: Some(GeminiGenerationConfig {
803 temperature,
804 top_p,
805 max_output_tokens,
806 thinking_config,
807 ..Default::default()
808 }),
809 };
810
811 let response = tokio::select! {
812 biased;
813 _ = token.cancelled() => {
814 debug!(target: "gemini::complete", "Cancellation token triggered before sending request.");
815 return Err(ApiError::Cancelled{ provider: self.name().to_string()});
816 }
817 res = self.client.post(&url).json(&request).send() => {
818 res.map_err(ApiError::Network)?
819 }
820 };
821 let status = response.status();
822
823 if status != StatusCode::OK {
824 let error_text = response.text().await.map_err(ApiError::Network)?;
825 error!(target: "Gemini API Error Response", "Status: {}, Body: {}", status, error_text);
826 return Err(match status.as_u16() {
827 401 | 403 => ApiError::AuthenticationFailed {
828 provider: self.name().to_string(),
829 details: error_text,
830 },
831 429 => ApiError::RateLimited {
832 provider: self.name().to_string(),
833 details: error_text,
834 },
835 400 | 404 => {
836 error!(target: "Gemini API Error Response", "Status: {}, Body: {}, Request: {}", status, error_text, serde_json::to_string_pretty(&request).unwrap_or_else(|_| "Failed to serialize request".to_string()));
837 ApiError::InvalidRequest {
838 provider: self.name().to_string(),
839 details: error_text,
840 }
841 } 500..=599 => ApiError::ServerError {
843 provider: self.name().to_string(),
844 status_code: status.as_u16(),
845 details: error_text,
846 },
847 _ => ApiError::Unknown {
848 provider: self.name().to_string(),
849 details: error_text,
850 },
851 });
852 }
853
854 let response_text = response.text().await.map_err(ApiError::Network)?;
855
856 match serde_json::from_str::<GeminiResponse>(&response_text) {
857 Ok(gemini_response) => {
858 convert_response(gemini_response).map_err(|e| ApiError::ResponseParsingError {
859 provider: self.name().to_string(),
860 details: e.to_string(),
861 })
862 }
863 Err(e) => {
864 error!(target: "Gemini API JSON Parsing Error", "Failed to parse JSON: {}. Response body:\n{}", e, response_text);
865 Err(ApiError::ResponseParsingError {
866 provider: self.name().to_string(),
867 details: format!("Status: {status}, Error: {e}, Body: {response_text}"),
868 })
869 }
870 }
871 }
872}
873
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_simplify_property_schema_removes_additional_properties() {
        let property_value = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            },
            "additionalProperties": false
        });

        let expected = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });

        let result = simplify_property_schema("testProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_removes_unsupported_string_formats() {
        let property_value = json!({
            "type": "string",
            "format": "uri",
            "minLength": 1,
            "maxLength": 100,
            "pattern": "^https://"
        });

        let expected = json!({
            "type": "string"
        });

        let result = simplify_property_schema("urlProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_keeps_supported_string_formats() {
        let property_value = json!({
            "type": "string",
            "format": "date-time"
        });

        let expected = json!({
            "type": "string",
            "format": "date-time"
        });

        let result = simplify_property_schema("dateProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_keeps_enum_string_format() {
        let property_value = json!({
            "type": "string",
            "format": "enum",
            "enum": ["low", "high"]
        });

        let result = simplify_property_schema("levelProp", "testTool", &property_value);
        assert_eq!(result, property_value);
    }

    #[test]
    fn test_simplify_property_schema_handles_array_types() {
        let property_value = json!({
            "type": ["string", "null"],
            "format": "email"
        });

        let expected = json!({
            "type": "string"
        });

        let result = simplify_property_schema("emailProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_returns_non_object_values_unchanged() {
        let property_value = json!(true);

        let result = simplify_property_schema("flagProp", "testTool", &property_value);
        assert_eq!(result, property_value);
    }

    #[test]
    fn test_simplify_property_schema_recursively_handles_array_items() {
        let property_value = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "url": {
                        "type": "string",
                        "format": "uri"
                    }
                },
                "additionalProperties": false
            }
        });

        let expected = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "url": {
                        "type": "string"
                    }
                }
            }
        });

        let result = simplify_property_schema("linksProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_recursively_handles_nested_objects() {
        let property_value = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "type": "object",
                    "properties": {
                        "field": {
                            "type": "string",
                            "format": "hostname"
                        }
                    },
                    "additionalProperties": true
                }
            },
            "additionalProperties": false
        });

        let expected = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "type": "object",
                    "properties": {
                        "field": {
                            "type": "string"
                        }
                    }
                }
            }
        });

        let result = simplify_property_schema("complexProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_fixes_uint64_format() {
        let property_value = json!({
            "type": "integer",
            "format": "uint64"
        });

        let expected = json!({
            "type": "integer",
            "format": "int64"
        });

        let result = simplify_property_schema("idProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_convert_tools_integration() {
        use steer_tools::{InputSchema, ToolSchema};

        let tool = ToolSchema {
            name: "create_issue".to_string(),
            description: "Create an issue".to_string(),
            input_schema: InputSchema {
                schema_type: "object".to_string(),
                properties: {
                    let mut props = serde_json::Map::new();
                    props.insert(
                        "title".to_string(),
                        json!({
                            "type": "string",
                            "minLength": 1
                        }),
                    );
                    props.insert(
                        "links".to_string(),
                        json!({
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "url": {
                                        "type": "string",
                                        "format": "uri"
                                    }
                                },
                                "additionalProperties": false
                            }
                        }),
                    );
                    props
                },
                required: vec!["title".to_string()],
            },
        };

        let expected_tools = vec![GeminiTool {
            function_declarations: vec![GeminiFunctionDeclaration {
                name: "create_issue".to_string(),
                description: "Create an issue".to_string(),
                parameters: GeminiParameterSchema {
                    schema_type: "object".to_string(),
                    properties: {
                        let mut props = serde_json::Map::new();
                        props.insert(
                            "title".to_string(),
                            json!({
                                "type": "string"
                            }),
                        );
                        props.insert(
                            "links".to_string(),
                            json!({
                                "type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "url": {
                                            "type": "string"
                                        }
                                    }
                                }
                            }),
                        );
                        props
                    },
                    required: vec!["title".to_string()],
                },
            }],
        }];

        let result = convert_tools(vec![tool]);
        assert_eq!(result, expected_tools);
    }

    #[test]
    fn test_convert_response_reports_blocked_prompt() {
        let response: GeminiResponse = serde_json::from_value(json!({
            "promptFeedback": { "blockReason": "SAFETY" }
        }))
        .expect("valid Gemini response JSON");

        let result = convert_response(response);
        assert!(matches!(result, Err(ApiError::RequestBlocked { .. })));
    }

    #[test]
    fn test_convert_response_maps_thought_text_and_function_call_parts() {
        let response: GeminiResponse = serde_json::from_value(json!({
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [
                        { "text": "planning", "thought": true },
                        { "text": "hello" },
                        { "functionCall": { "name": "lookup", "args": { "q": "rust" } } }
                    ]
                },
                "finishReason": "STOP"
            }]
        }))
        .expect("valid Gemini response JSON");

        let content = match convert_response(response) {
            Ok(completion) => completion.content,
            Err(_) => panic!("expected the response to convert successfully"),
        };

        assert_eq!(content.len(), 3);
        assert!(matches!(
            &content[0],
            AssistantContent::Thought { thought: ThoughtContent::Simple { text } } if text == "planning"
        ));
        assert!(matches!(&content[1], AssistantContent::Text { text } if text == "hello"));
        assert!(matches!(
            &content[2],
            AssistantContent::ToolCall { tool_call }
                if tool_call.name == "lookup" && tool_call.parameters == json!({ "q": "rust" })
        ));
    }
}