1use async_trait::async_trait;
2use reqwest::{Client as HttpClient, StatusCode};
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use tokio_util::sync::CancellationToken;
6use tracing::{debug, error, info, warn};
7
8use crate::api::Model;
9use crate::api::error::ApiError;
10use crate::api::provider::{CompletionResponse, Provider};
11use crate::app::conversation::{
12 AssistantContent, Message as AppMessage, ThoughtContent, ToolResult, UserContent,
13};
14use steer_tools::ToolSchema;
15
16const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
17
18#[derive(Debug, Deserialize, Serialize, Clone)] struct GeminiBlob {
20 #[serde(rename = "mimeType")]
21 mime_type: String,
22 data: String, }
24
25#[derive(Debug, Deserialize, Serialize, Clone)] struct GeminiFileData {
27 #[serde(rename = "mimeType")]
28 mime_type: String,
29 #[serde(rename = "fileUri")]
30 file_uri: String,
31}
32
33#[derive(Debug, Deserialize, Serialize, Clone)] struct GeminiCodeExecutionResult {
35 outcome: String, }
38
39pub struct GeminiClient {
40 api_key: String,
41 client: HttpClient,
42}
43
44impl GeminiClient {
45 pub fn new(api_key: impl Into<String>) -> Self {
46 Self {
47 api_key: api_key.into(),
48 client: HttpClient::new(),
49 }
50 }
51}
52
53#[derive(Debug, Serialize)]
54struct GeminiRequest {
55 contents: Vec<GeminiContent>,
56 #[serde(skip_serializing_if = "Option::is_none")]
57 #[serde(rename = "systemInstruction")]
58 system_instruction: Option<GeminiSystemInstruction>,
59 #[serde(skip_serializing_if = "Option::is_none")]
60 tools: Option<Vec<GeminiTool>>,
61 #[serde(skip_serializing_if = "Option::is_none")]
62 #[serde(rename = "generationConfig")]
63 generation_config: Option<GeminiGenerationConfig>,
64}
65
66#[derive(Debug, Serialize, Default, Clone)]
67struct GeminiGenerationConfig {
68 #[serde(skip_serializing_if = "Option::is_none")]
69 #[serde(rename = "stopSequences")]
70 stop_sequences: Option<Vec<String>>,
71 #[serde(skip_serializing_if = "Option::is_none")]
72 #[serde(rename = "responseMimeType")]
73 response_mime_type: Option<GeminiMimeType>,
74 #[serde(skip_serializing_if = "Option::is_none")]
75 #[serde(rename = "candidateCount")]
76 candidate_count: Option<i32>,
77 #[serde(skip_serializing_if = "Option::is_none")]
78 #[serde(rename = "maxOutputTokens")]
79 max_output_tokens: Option<i32>,
80 #[serde(skip_serializing_if = "Option::is_none")]
81 temperature: Option<f32>,
82 #[serde(skip_serializing_if = "Option::is_none")]
83 #[serde(rename = "topP")]
84 top_p: Option<f32>,
85 #[serde(skip_serializing_if = "Option::is_none")]
86 #[serde(rename = "topK")]
87 top_k: Option<i32>,
88 #[serde(skip_serializing_if = "Option::is_none")]
89 #[serde(rename = "thinkingConfig")]
90 thinking_config: Option<GeminiThinkingConfig>,
91}
92
93#[derive(Debug, Serialize, Default, Clone)]
94struct GeminiThinkingConfig {
95 #[serde(skip_serializing_if = "Option::is_none")]
96 #[serde(rename = "includeThoughts")]
97 include_thoughts: Option<bool>,
98 #[serde(skip_serializing_if = "Option::is_none")]
99 #[serde(rename = "thinkingBudget")]
100 thinking_budget: Option<i32>,
101}
102
103#[allow(dead_code)]
104#[derive(Debug, Serialize, Clone)]
105#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
106enum GeminiMimeType {
107 MimeTypeUnspecified,
108 TextPlain,
109 ApplicationJson,
110}
111
112#[derive(Debug, Serialize)]
113struct GeminiSystemInstruction {
114 parts: Vec<GeminiRequestPart>,
115}
116
117#[derive(Debug, Serialize)]
118struct GeminiContent {
119 role: String,
120 parts: Vec<GeminiRequestPart>,
121}
122
123#[derive(Debug, Serialize)]
125#[serde(untagged)]
126enum GeminiRequestPart {
127 Text {
128 text: String,
129 },
130 #[serde(rename = "functionCall")]
131 FunctionCall {
132 #[serde(rename = "functionCall")]
133 function_call: GeminiFunctionCall, },
135 #[serde(rename = "functionResponse")]
136 FunctionResponse {
137 #[serde(rename = "functionResponse")]
138 function_response: GeminiFunctionResponse, },
140}
141
142#[derive(Debug, Deserialize)]
144#[serde(untagged)]
145enum GeminiResponsePartData {
146 Text {
147 text: String,
148 },
149 #[serde(rename = "inlineData")]
150 InlineData {
151 #[serde(rename = "inlineData")]
152 inline_data: GeminiBlob,
153 },
154 #[serde(rename = "functionCall")]
155 FunctionCall {
156 #[serde(rename = "functionCall")]
157 function_call: GeminiFunctionCall,
158 },
159 #[serde(rename = "fileData")]
160 FileData {
161 #[serde(rename = "fileData")]
162 file_data: GeminiFileData,
163 },
164 #[serde(rename = "executableCode")]
165 ExecutableCode {
166 #[serde(rename = "executableCode")]
167 executable_code: GeminiExecutableCode,
168 },
169 }
171
172#[derive(Debug, Deserialize)]
174struct GeminiResponsePart {
175 #[serde(default)] thought: bool,
177
178 #[serde(flatten)] data: GeminiResponsePartData,
180}
181
182#[derive(Debug, Serialize, Deserialize)]
183struct GeminiFunctionCall {
184 name: String,
185 args: Value,
186}
187
188#[derive(Debug, Serialize, PartialEq)]
189struct GeminiTool {
190 #[serde(rename = "functionDeclarations")]
191 function_declarations: Vec<GeminiFunctionDeclaration>,
192}
193
194#[derive(Debug, Serialize, PartialEq)]
195struct GeminiFunctionDeclaration {
196 name: String,
197 description: String,
198 parameters: GeminiParameterSchema,
199}
200
201#[derive(Debug, Serialize, PartialEq)]
202struct GeminiParameterSchema {
203 #[serde(rename = "type")]
204 schema_type: String, properties: serde_json::Map<String, Value>,
206 required: Vec<String>,
207}
208
209#[derive(Debug, Deserialize)]
210struct GeminiResponse {
211 #[serde(rename = "candidates")]
212 #[serde(skip_serializing_if = "Option::is_none")]
213 candidates: Option<Vec<GeminiCandidate>>,
214 #[serde(rename = "promptFeedback")]
215 #[serde(skip_serializing_if = "Option::is_none")]
216 prompt_feedback: Option<GeminiPromptFeedback>,
217 #[serde(rename = "usageMetadata")]
218 #[serde(skip_serializing_if = "Option::is_none")]
219 usage_metadata: Option<GeminiUsageMetadata>,
220}
221
222#[derive(Debug, Deserialize)]
223struct GeminiCandidate {
224 content: GeminiContentResponse,
225 #[serde(rename = "finishReason")]
226 #[serde(skip_serializing_if = "Option::is_none")]
227 finish_reason: Option<GeminiFinishReason>,
228 #[serde(rename = "safetyRatings")]
229 #[serde(skip_serializing_if = "Option::is_none")]
230 safety_ratings: Option<Vec<GeminiSafetyRating>>,
231 #[serde(rename = "citationMetadata")]
232 #[serde(skip_serializing_if = "Option::is_none")]
233 citation_metadata: Option<GeminiCitationMetadata>,
234}
235
236#[derive(Debug, Deserialize)]
237#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
238enum GeminiFinishReason {
239 FinishReasonUnspecified,
240 Stop,
241 MaxTokens,
242 Safety,
243 Recitation,
244 Other,
245 #[serde(rename = "TOOL_CODE_ERROR")]
246 ToolCodeError,
247 #[serde(rename = "TOOL_EXECUTION_HALT")]
248 ToolExecutionHalt,
249 MalformedFunctionCall,
250}
251
252#[derive(Debug, Deserialize)]
253struct GeminiPromptFeedback {
254 #[serde(rename = "blockReason")]
255 #[serde(skip_serializing_if = "Option::is_none")]
256 block_reason: Option<GeminiBlockReason>,
257 #[serde(rename = "safetyRatings")]
258 #[serde(skip_serializing_if = "Option::is_none")]
259 safety_ratings: Option<Vec<GeminiSafetyRating>>,
260}
261
262#[derive(Debug, Deserialize)]
263#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
264enum GeminiBlockReason {
265 BlockReasonUnspecified,
266 Safety,
267 Other,
268}
269
270#[derive(Debug, Deserialize)]
271#[allow(dead_code)]
272struct GeminiSafetyRating {
273 category: GeminiHarmCategory,
274 probability: GeminiHarmProbability,
275 #[serde(default)] blocked: bool,
277}
278
279#[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
281#[allow(clippy::enum_variant_names)]
282enum GeminiHarmCategory {
283 HarmCategoryUnspecified,
284 HarmCategoryDerogatory,
285 HarmCategoryToxicity,
286 HarmCategoryViolence,
287 HarmCategorySexual,
288 HarmCategoryMedical,
289 HarmCategoryDangerous,
290 HarmCategoryHarassment,
291 HarmCategoryHateSpeech,
292 HarmCategorySexuallyExplicit,
293 HarmCategoryDangerousContent,
294 HarmCategoryCivicIntegrity,
295}
296
297#[derive(Debug, Deserialize)]
298#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
299enum GeminiHarmProbability {
300 HarmProbabilityUnspecified,
301 Negligible,
302 Low,
303 Medium,
304 High,
305}
306
307#[allow(dead_code)]
308#[derive(Debug, Deserialize)]
309struct GeminiCitationMetadata {
310 #[serde(rename = "citationSources")]
311 #[serde(skip_serializing_if = "Option::is_none")]
312 citation_sources: Option<Vec<GeminiCitationSource>>,
313}
314
315#[allow(dead_code)]
316#[derive(Debug, Deserialize)]
317struct GeminiCitationSource {
318 #[serde(rename = "startIndex")]
319 #[serde(skip_serializing_if = "Option::is_none")]
320 start_index: Option<i32>,
321 #[serde(rename = "endIndex")]
322 #[serde(skip_serializing_if = "Option::is_none")]
323 end_index: Option<i32>,
324 #[serde(skip_serializing_if = "Option::is_none")]
325 uri: Option<String>,
326 #[serde(skip_serializing_if = "Option::is_none")]
327 license: Option<String>,
328}
329
330#[derive(Debug, Deserialize)]
331struct GeminiUsageMetadata {
332 #[serde(rename = "promptTokenCount")]
333 #[serde(skip_serializing_if = "Option::is_none")]
334 prompt_token_count: Option<i32>,
335 #[serde(rename = "candidatesTokenCount")]
336 #[serde(skip_serializing_if = "Option::is_none")]
337 candidates_token_count: Option<i32>,
338 #[serde(rename = "totalTokenCount")]
339 #[serde(skip_serializing_if = "Option::is_none")]
340 total_token_count: Option<i32>,
341}
342
343#[derive(Debug, Serialize, Deserialize)]
344struct GeminiFunctionResponse {
345 name: String,
346 response: GeminiResponseContent,
347}
348
349#[derive(Debug, Serialize, Deserialize)]
350struct GeminiResponseContent {
351 content: Value,
352}
353
354#[derive(Debug, Serialize, Deserialize)]
355struct GeminiExecutableCode {
356 language: String, code: String,
358}
359
360#[derive(Debug, Deserialize)]
361#[allow(dead_code)]
362struct GeminiContentResponse {
363 role: String,
364 parts: Vec<GeminiResponsePart>,
365}
366
367fn convert_messages(messages: Vec<AppMessage>) -> Vec<GeminiContent> {
368 messages
369 .into_iter()
370 .filter_map(|msg| match &msg.data {
371 crate::app::conversation::MessageData::User { content, .. } => {
372 let parts: Vec<GeminiRequestPart> = content
373 .iter()
374 .filter_map(|user_content| match user_content {
375 UserContent::Text { text } => {
376 Some(GeminiRequestPart::Text { text: text.clone() })
377 }
378 UserContent::CommandExecution {
379 command,
380 stdout,
381 stderr,
382 exit_code,
383 } => Some(GeminiRequestPart::Text {
384 text: UserContent::format_command_execution_as_xml(
385 command, stdout, stderr, *exit_code,
386 ),
387 }),
388 UserContent::AppCommand { .. } => {
389 None
391 }
392 })
393 .collect();
394
395 if parts.is_empty() {
397 None
398 } else {
399 Some(GeminiContent {
400 role: "user".to_string(),
401 parts,
402 })
403 }
404 }
405 crate::app::conversation::MessageData::Assistant { content, .. } => {
406 let parts: Vec<GeminiRequestPart> = content
407 .iter()
408 .filter_map(|assistant_content| match assistant_content {
409 AssistantContent::Text { text } => {
410 Some(GeminiRequestPart::Text { text: text.clone() })
411 }
412 AssistantContent::ToolCall { tool_call } => {
413 Some(GeminiRequestPart::FunctionCall {
414 function_call: GeminiFunctionCall {
415 name: tool_call.name.clone(),
416 args: tool_call.parameters.clone(),
417 },
418 })
419 }
420 AssistantContent::Thought { .. } => {
421 None
423 }
424 })
425 .collect();
426
427 Some(GeminiContent {
429 role: "model".to_string(),
430 parts,
431 })
432 }
433 crate::app::conversation::MessageData::Tool {
434 tool_use_id,
435 result,
436 ..
437 } => {
438 let result_value = match result {
440 ToolResult::Error(e) => Value::String(format!("Error: {e}")),
441 _ => {
442 serde_json::to_value(result)
444 .unwrap_or_else(|_| Value::String(result.llm_format()))
445 }
446 };
447
448 let parts = vec![GeminiRequestPart::FunctionResponse {
449 function_response: GeminiFunctionResponse {
450 name: tool_use_id.clone(), response: GeminiResponseContent {
452 content: result_value,
453 },
454 },
455 }];
456
457 Some(GeminiContent {
458 role: "function".to_string(),
459 parts,
460 })
461 }
462 })
463 .collect()
464}
465
466fn simplify_property_schema(key: &str, tool_name: &str, property_value: &Value) -> Value {
467 if let Some(prop_map_orig) = property_value.as_object() {
468 let mut simplified_prop = prop_map_orig.clone();
469
470 if simplified_prop.remove("additionalProperties").is_some() {
472 debug!(target: "gemini::simplify_property_schema", "Removed 'additionalProperties' from property '{}' in tool '{}'", key, tool_name);
473 }
474
475 if let Some(type_val) = simplified_prop.get_mut("type") {
477 if let Some(type_array) = type_val.as_array() {
478 if let Some(primary_type) = type_array
479 .iter()
480 .find_map(|v| if !v.is_null() { v.as_str() } else { None })
481 {
482 *type_val = serde_json::Value::String(primary_type.to_string());
483 } else {
484 warn!(target: "gemini::simplify_property_schema", "Could not determine primary type for property '{}' in tool '{}', defaulting to string.", key, tool_name);
485 *type_val = serde_json::Value::String("string".to_string());
486 }
487 } else if !type_val.is_string() {
488 warn!(target: "gemini::simplify_property_schema", "Unexpected 'type' format for property '{}' in tool '{}': {:?}. Defaulting to string.", key, tool_name, type_val);
489 *type_val = serde_json::Value::String("string".to_string());
490 }
491 }
493
494 if simplified_prop.get("type") == Some(&serde_json::Value::String("integer".to_string())) {
496 if let Some(format_val) = simplified_prop.get_mut("format") {
497 if format_val.as_str() == Some("uint64") {
498 *format_val = serde_json::Value::String("int64".to_string());
499 }
502 }
503 }
504
505 if simplified_prop.get("type") == Some(&serde_json::Value::String("string".to_string())) {
507 let should_remove_format = simplified_prop
508 .get("format")
509 .and_then(|f| f.as_str())
510 .map(|format_str| format_str != "enum" && format_str != "date-time")
511 .unwrap_or(false);
512
513 if should_remove_format {
514 if let Some(format_val) = simplified_prop.remove("format") {
515 if let Some(format_str) = format_val.as_str() {
516 debug!(target: "gemini::simplify_property_schema", "Removed unsupported format '{}' from string property '{}' in tool '{}'", format_str, key, tool_name);
517 }
518 }
519 }
520
521 if simplified_prop.remove("minLength").is_some() {
523 debug!(target: "gemini::simplify_property_schema", "Removed 'minLength' from string property '{}' in tool '{}'", key, tool_name);
524 }
525 if simplified_prop.remove("maxLength").is_some() {
526 debug!(target: "gemini::simplify_property_schema", "Removed 'maxLength' from string property '{}' in tool '{}'", key, tool_name);
527 }
528 if simplified_prop.remove("pattern").is_some() {
529 debug!(target: "gemini::simplify_property_schema", "Removed 'pattern' from string property '{}' in tool '{}'", key, tool_name);
530 }
531 }
532
533 if simplified_prop.get("type") == Some(&serde_json::Value::String("array".to_string())) {
535 if let Some(items_val) = simplified_prop.get_mut("items") {
536 *items_val =
537 simplify_property_schema(&format!("{key}.items"), tool_name, items_val);
538 }
539 }
540
541 if simplified_prop.get("type") == Some(&serde_json::Value::String("object".to_string())) {
543 if let Some(Value::Object(props)) = simplified_prop.get_mut("properties") {
544 let simplified_nested_props: serde_json::Map<String, Value> = props
545 .iter()
546 .map(|(nested_key, nested_value)| {
547 (
548 nested_key.clone(),
549 simplify_property_schema(
550 &format!("{key}.{nested_key}"),
551 tool_name,
552 nested_value,
553 ),
554 )
555 })
556 .collect();
557 *props = simplified_nested_props;
558 }
559 }
560
561 serde_json::Value::Object(simplified_prop)
562 } else {
563 warn!(target: "gemini::simplify_property_schema", "Property value for '{}' in tool '{}' is not an object: {:?}. Using original value.", key, tool_name, property_value);
564 property_value.clone() }
566}
567
568fn convert_tools(tools: Vec<ToolSchema>) -> Vec<GeminiTool> {
569 let function_declarations = tools
570 .into_iter()
571 .map(|tool| {
572 let simplified_properties = tool
574 .input_schema
575 .properties
576 .iter()
577 .map(|(key, value)| {
578 (
579 key.clone(),
580 simplify_property_schema(key, &tool.name, value),
581 )
582 })
583 .collect();
584
585 let parameters = GeminiParameterSchema {
587 schema_type: tool.input_schema.schema_type, properties: simplified_properties, required: tool.input_schema.required, };
591
592 GeminiFunctionDeclaration {
593 name: tool.name,
594 description: tool.description,
595 parameters,
596 }
597 })
598 .collect();
599
600 vec![GeminiTool {
601 function_declarations,
602 }]
603}
604
605fn convert_response(response: GeminiResponse) -> Result<CompletionResponse, ApiError> {
606 if let Some(feedback) = &response.prompt_feedback {
608 if let Some(reason) = &feedback.block_reason {
609 let details = format!(
610 "Prompt blocked due to {:?}. Safety ratings: {:?}",
611 reason, feedback.safety_ratings
612 );
613 warn!(target: "gemini::convert_response", "{}", details);
614 return Err(ApiError::RequestBlocked {
616 provider: "google".to_string(), details,
618 });
619 }
620 }
621
622 let candidates = match response.candidates {
624 Some(cands) => {
625 if cands.is_empty() {
626 warn!(target: "gemini::convert_response", "No candidates received, and prompt was not blocked.");
629 return Err(ApiError::NoChoices {
631 provider: "google".to_string(),
632 });
633 }
634 cands }
636 None => {
637 warn!(target: "gemini::convert_response", "No candidates field in Gemini response.");
638 return Err(ApiError::NoChoices {
640 provider: "google".to_string(),
641 });
642 }
643 };
644
645 let candidate = &candidates[0];
648
649 if let Some(reason) = &candidate.finish_reason {
651 match reason {
652 GeminiFinishReason::Stop => { }
653 GeminiFinishReason::MaxTokens => {
654 warn!(target: "gemini::convert_response", "Response stopped due to MaxTokens limit.");
655 }
656 GeminiFinishReason::Safety => {
657 warn!(target: "gemini::convert_response", "Response stopped due to safety settings. Ratings: {:?}", candidate.safety_ratings);
658 }
660 GeminiFinishReason::Recitation => {
661 warn!(target: "gemini::convert_response", "Response stopped due to potential recitation. Citations: {:?}", candidate.citation_metadata);
662 }
663 GeminiFinishReason::MalformedFunctionCall => {
664 warn!(target: "gemini::convert_response", "Response stopped due to malformed function call.");
665 }
666 _ => {
667 info!(target: "gemini::convert_response", "Response finished with reason: {:?}", reason);
668 }
669 }
670 }
671
672 if let Some(usage) = &response.usage_metadata {
674 debug!(target: "gemini::convert_response", "Usage - Prompt Tokens: {:?}, Candidates Tokens: {:?}, Total Tokens: {:?}",
675 usage.prompt_token_count, usage.candidates_token_count, usage.total_token_count);
676 }
677
678 let content: Vec<AssistantContent> = candidate
679 .content .parts .iter()
682 .filter_map(|part| { if part.thought {
685 debug!(target: "gemini::convert_response", "Received thought part: {:?}", part);
686 match &part.data {
688 GeminiResponsePartData::Text { text } => {
689 Some(AssistantContent::Thought {
690 thought: ThoughtContent::Simple {
691 text: text.clone(),
692 },
693 })
694 }
695 _ => {
696 warn!(target: "gemini::convert_response", "Thought part contains non-text data: {:?}", part.data);
697 None
698 }
699 }
700 } else {
701 match &part.data {
703 GeminiResponsePartData::Text { text } => Some(AssistantContent::Text {
704 text: text.clone(),
705 }),
706 GeminiResponsePartData::InlineData { inline_data } => {
707 warn!(target: "gemini::convert_response", "Received InlineData part (MIME type: {}). Converting to placeholder text.", inline_data.mime_type);
708 Some(AssistantContent::Text { text: format!("[Inline Data: {}]", inline_data.mime_type) })
709 }
710 GeminiResponsePartData::FunctionCall { function_call } => {
711 Some(AssistantContent::ToolCall {
712 tool_call: steer_tools::ToolCall {
713 id: uuid::Uuid::new_v4().to_string(), name: function_call.name.clone(),
715 parameters: function_call.args.clone(),
716 },
717 })
718 }
719 GeminiResponsePartData::FileData { file_data } => {
720 warn!(target: "gemini::convert_response", "Received FileData part (URI: {}). Converting to placeholder text.", file_data.file_uri);
721 Some(AssistantContent::Text { text: format!("[File Data: {}]", file_data.file_uri) })
722 }
723 GeminiResponsePartData::ExecutableCode { executable_code } => {
724 info!(target: "gemini::convert_response", "Received ExecutableCode part ({}). Converting to text.",
725 executable_code.language);
726 Some(AssistantContent::Text {
727 text: format!(
728 "```{}
729{}
730```",
731 executable_code.language.to_lowercase(),
732 executable_code.code
733 ),
734 })
735 }
736 }
737 }
738 })
739 .collect();
740
741 Ok(CompletionResponse { content })
742}
743
744#[async_trait]
745impl Provider for GeminiClient {
746 fn name(&self) -> &'static str {
747 "google"
748 }
749
750 async fn complete(
751 &self,
752 model: Model,
753 messages: Vec<AppMessage>,
754 system: Option<String>,
755 tools: Option<Vec<ToolSchema>>,
756 token: CancellationToken,
757 ) -> Result<CompletionResponse, ApiError> {
758 let model_name = model.as_ref();
759 let url = format!(
760 "{}/models/{}:generateContent?key={}",
761 GEMINI_API_BASE, model_name, self.api_key
762 );
763
764 let gemini_contents = convert_messages(messages);
765
766 let system_instruction = system.map(|instructions| GeminiSystemInstruction {
767 parts: vec![GeminiRequestPart::Text { text: instructions }],
768 });
769
770 let gemini_tools = tools.map(convert_tools);
771
772 let request = GeminiRequest {
773 contents: gemini_contents,
774 system_instruction,
775 tools: gemini_tools,
776 generation_config: Some(GeminiGenerationConfig {
777 temperature: Some(1.0),
778 top_p: Some(0.95),
779 max_output_tokens: Some(65536),
780 thinking_config: Some(GeminiThinkingConfig {
781 include_thoughts: Some(true),
782 thinking_budget: Some(8192),
783 }),
784 ..Default::default()
785 }),
786 };
787
788 let response = tokio::select! {
789 biased;
790 _ = token.cancelled() => {
791 debug!(target: "gemini::complete", "Cancellation token triggered before sending request.");
792 return Err(ApiError::Cancelled{ provider: self.name().to_string()});
793 }
794 res = self.client.post(&url).json(&request).send() => {
795 res.map_err(ApiError::Network)?
796 }
797 };
798 let status = response.status();
799
800 if status != StatusCode::OK {
801 let error_text = response.text().await.map_err(ApiError::Network)?;
802 error!(target: "Gemini API Error Response", "Status: {}, Body: {}", status, error_text);
803 return Err(match status.as_u16() {
804 401 | 403 => ApiError::AuthenticationFailed {
805 provider: self.name().to_string(),
806 details: error_text,
807 },
808 429 => ApiError::RateLimited {
809 provider: self.name().to_string(),
810 details: error_text,
811 },
812 400 | 404 => {
813 error!(target: "Gemini API Error Response", "Status: {}, Body: {}, Request: {}", status, error_text, serde_json::to_string_pretty(&request).unwrap_or_else(|_| "Failed to serialize request".to_string()));
814 ApiError::InvalidRequest {
815 provider: self.name().to_string(),
816 details: error_text,
817 }
818 } 500..=599 => ApiError::ServerError {
820 provider: self.name().to_string(),
821 status_code: status.as_u16(),
822 details: error_text,
823 },
824 _ => ApiError::Unknown {
825 provider: self.name().to_string(),
826 details: error_text,
827 },
828 });
829 }
830
831 let response_text = response.text().await.map_err(ApiError::Network)?;
832
833 match serde_json::from_str::<GeminiResponse>(&response_text) {
834 Ok(gemini_response) => {
835 convert_response(gemini_response).map_err(|e| ApiError::ResponseParsingError {
836 provider: self.name().to_string(),
837 details: e.to_string(),
838 })
839 }
840 Err(e) => {
841 error!(target: "Gemini API JSON Parsing Error", "Failed to parse JSON: {}. Response body:\n{}", e, response_text);
842 Err(ApiError::ResponseParsingError {
843 provider: self.name().to_string(),
844 details: format!("Status: {status}, Error: {e}, Body: {response_text}"),
845 })
846 }
847 }
848 }
849}
850
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Simplifies `input` as a property of a tool named `testTool` and asserts the result.
    fn assert_simplifies_to(key: &str, input: Value, expected: Value) {
        let actual = simplify_property_schema(key, "testTool", &input);
        assert_eq!(actual, expected);
    }

    /// Unwraps a `json!` object literal into the map type used by tool schemas.
    fn object_map(value: Value) -> serde_json::Map<String, Value> {
        match value {
            Value::Object(map) => map,
            other => panic!("expected a JSON object, got {other:?}"),
        }
    }

    #[test]
    fn test_simplify_property_schema_removes_additional_properties() {
        assert_simplifies_to(
            "testProp",
            json!({
                "type": "object",
                "properties": { "name": { "type": "string" } },
                "additionalProperties": false
            }),
            json!({
                "type": "object",
                "properties": { "name": { "type": "string" } }
            }),
        );
    }

    #[test]
    fn test_simplify_property_schema_removes_unsupported_string_formats() {
        assert_simplifies_to(
            "urlProp",
            json!({
                "type": "string",
                "format": "uri",
                "minLength": 1,
                "maxLength": 100,
                "pattern": "^https://"
            }),
            json!({ "type": "string" }),
        );
    }

    #[test]
    fn test_simplify_property_schema_keeps_supported_string_formats() {
        // date-time is supported, so the schema must come back untouched.
        let date_time = json!({ "type": "string", "format": "date-time" });
        assert_simplifies_to("dateProp", date_time.clone(), date_time);
    }

    #[test]
    fn test_simplify_property_schema_handles_array_types() {
        assert_simplifies_to(
            "emailProp",
            json!({ "type": ["string", "null"], "format": "email" }),
            json!({ "type": "string" }),
        );
    }

    #[test]
    fn test_simplify_property_schema_recursively_handles_array_items() {
        assert_simplifies_to(
            "linksProp",
            json!({
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": { "url": { "type": "string", "format": "uri" } },
                    "additionalProperties": false
                }
            }),
            json!({
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": { "url": { "type": "string" } }
                }
            }),
        );
    }

    #[test]
    fn test_simplify_property_schema_recursively_handles_nested_objects() {
        assert_simplifies_to(
            "complexProp",
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "type": "object",
                        "properties": {
                            "field": { "type": "string", "format": "hostname" }
                        },
                        "additionalProperties": true
                    }
                },
                "additionalProperties": false
            }),
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "type": "object",
                        "properties": { "field": { "type": "string" } }
                    }
                }
            }),
        );
    }

    #[test]
    fn test_simplify_property_schema_fixes_uint64_format() {
        assert_simplifies_to(
            "idProp",
            json!({ "type": "integer", "format": "uint64" }),
            json!({ "type": "integer", "format": "int64" }),
        );
    }

    #[test]
    fn test_convert_tools_integration() {
        use steer_tools::{InputSchema, ToolSchema};

        let tool = ToolSchema {
            name: "create_issue".to_string(),
            description: "Create an issue".to_string(),
            input_schema: InputSchema {
                schema_type: "object".to_string(),
                properties: object_map(json!({
                    "title": { "type": "string", "minLength": 1 },
                    "links": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": { "url": { "type": "string", "format": "uri" } },
                            "additionalProperties": false
                        }
                    }
                })),
                required: vec!["title".to_string()],
            },
        };

        let expected_tools = vec![GeminiTool {
            function_declarations: vec![GeminiFunctionDeclaration {
                name: "create_issue".to_string(),
                description: "Create an issue".to_string(),
                parameters: GeminiParameterSchema {
                    schema_type: "object".to_string(),
                    properties: object_map(json!({
                        "title": { "type": "string" },
                        "links": {
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": { "url": { "type": "string" } }
                            }
                        }
                    })),
                    required: vec!["title".to_string()],
                },
            }],
        }];

        assert_eq!(convert_tools(vec![tool]), expected_tools);
    }
}