// steer_core/api/gemini/client.rs

1use async_trait::async_trait;
2use reqwest::{Client as HttpClient, StatusCode};
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use tokio_util::sync::CancellationToken;
6use tracing::{debug, error, info, warn};
7
8use crate::api::Model;
9use crate::api::error::ApiError;
10use crate::api::provider::{CompletionResponse, Provider};
11use crate::app::conversation::{
12    AssistantContent, Message as AppMessage, ThoughtContent, ToolResult, UserContent,
13};
14use steer_tools::ToolSchema;
15
/// Base URL of the Gemini REST API (`v1beta`); request URLs are built by appending
/// `/models/{model}:generateContent` to it.
const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
17
18#[derive(Debug, Deserialize, Serialize, Clone)] // Added Serialize and Clone for potential future use
19struct GeminiBlob {
20    #[serde(rename = "mimeType")]
21    mime_type: String,
22    data: String, // Assuming base64 encoded data
23}
24
25#[derive(Debug, Deserialize, Serialize, Clone)] // Added Serialize and Clone
26struct GeminiFileData {
27    #[serde(rename = "mimeType")]
28    mime_type: String,
29    #[serde(rename = "fileUri")]
30    file_uri: String,
31}
32
33#[derive(Debug, Deserialize, Serialize, Clone)] // Added Serialize and Clone
34struct GeminiCodeExecutionResult {
35    outcome: String, // e.g., "OK", "ERROR"
36                     // Potentially add output field later if needed
37}
38
39pub struct GeminiClient {
40    api_key: String,
41    client: HttpClient,
42}
43
44impl GeminiClient {
45    pub fn new(api_key: impl Into<String>) -> Self {
46        Self {
47            api_key: api_key.into(),
48            client: HttpClient::new(),
49        }
50    }
51}
52
53#[derive(Debug, Serialize)]
54struct GeminiRequest {
55    contents: Vec<GeminiContent>,
56    #[serde(skip_serializing_if = "Option::is_none")]
57    #[serde(rename = "systemInstruction")]
58    system_instruction: Option<GeminiSystemInstruction>,
59    #[serde(skip_serializing_if = "Option::is_none")]
60    tools: Option<Vec<GeminiTool>>,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    #[serde(rename = "generationConfig")]
63    generation_config: Option<GeminiGenerationConfig>,
64}
65
66#[derive(Debug, Serialize, Default, Clone)]
67struct GeminiGenerationConfig {
68    #[serde(skip_serializing_if = "Option::is_none")]
69    #[serde(rename = "stopSequences")]
70    stop_sequences: Option<Vec<String>>,
71    #[serde(skip_serializing_if = "Option::is_none")]
72    #[serde(rename = "responseMimeType")]
73    response_mime_type: Option<GeminiMimeType>,
74    #[serde(skip_serializing_if = "Option::is_none")]
75    #[serde(rename = "candidateCount")]
76    candidate_count: Option<i32>,
77    #[serde(skip_serializing_if = "Option::is_none")]
78    #[serde(rename = "maxOutputTokens")]
79    max_output_tokens: Option<i32>,
80    #[serde(skip_serializing_if = "Option::is_none")]
81    temperature: Option<f32>,
82    #[serde(skip_serializing_if = "Option::is_none")]
83    #[serde(rename = "topP")]
84    top_p: Option<f32>,
85    #[serde(skip_serializing_if = "Option::is_none")]
86    #[serde(rename = "topK")]
87    top_k: Option<i32>,
88    #[serde(skip_serializing_if = "Option::is_none")]
89    #[serde(rename = "thinkingConfig")]
90    thinking_config: Option<GeminiThinkingConfig>,
91}
92
93#[derive(Debug, Serialize, Default, Clone)]
94struct GeminiThinkingConfig {
95    #[serde(skip_serializing_if = "Option::is_none")]
96    #[serde(rename = "includeThoughts")]
97    include_thoughts: Option<bool>,
98    #[serde(skip_serializing_if = "Option::is_none")]
99    #[serde(rename = "thinkingBudget")]
100    thinking_budget: Option<i32>,
101}
102
103#[allow(dead_code)]
104#[derive(Debug, Serialize, Clone)]
105#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
106enum GeminiMimeType {
107    MimeTypeUnspecified,
108    TextPlain,
109    ApplicationJson,
110}
111
112#[derive(Debug, Serialize)]
113struct GeminiSystemInstruction {
114    parts: Vec<GeminiRequestPart>,
115}
116
117#[derive(Debug, Serialize)]
118struct GeminiContent {
119    role: String,
120    parts: Vec<GeminiRequestPart>,
121}
122
123// Enum for parts used ONLY in requests
124#[derive(Debug, Serialize)]
125#[serde(untagged)]
126enum GeminiRequestPart {
127    Text {
128        text: String,
129    },
130    #[serde(rename = "functionCall")]
131    FunctionCall {
132        #[serde(rename = "functionCall")]
133        function_call: GeminiFunctionCall, // Used for model turns in history
134    },
135    #[serde(rename = "functionResponse")]
136    FunctionResponse {
137        #[serde(rename = "functionResponse")]
138        function_response: GeminiFunctionResponse, // Used for function/tool turns
139    },
140}
141
142// Enum for parts received ONLY in responses
143#[derive(Debug, Deserialize)]
144#[serde(untagged)]
145enum GeminiResponsePartData {
146    Text {
147        text: String,
148    },
149    #[serde(rename = "inlineData")]
150    InlineData {
151        #[serde(rename = "inlineData")]
152        inline_data: GeminiBlob,
153    },
154    #[serde(rename = "functionCall")]
155    FunctionCall {
156        #[serde(rename = "functionCall")]
157        function_call: GeminiFunctionCall,
158    },
159    #[serde(rename = "fileData")]
160    FileData {
161        #[serde(rename = "fileData")]
162        file_data: GeminiFileData,
163    },
164    #[serde(rename = "executableCode")]
165    ExecutableCode {
166        #[serde(rename = "executableCode")]
167        executable_code: GeminiExecutableCode,
168    },
169    // Add other variants back here if needed
170}
171
172// 2. Change GeminiResponsePart to a struct
173#[derive(Debug, Deserialize)]
174struct GeminiResponsePart {
175    #[serde(default)] // Defaults to false if missing
176    thought: bool,
177
178    #[serde(flatten)] // Look for data fields directly in this struct's JSON
179    data: GeminiResponsePartData,
180}
181
182#[derive(Debug, Serialize, Deserialize)]
183struct GeminiFunctionCall {
184    name: String,
185    args: Value,
186}
187
188#[derive(Debug, Serialize, PartialEq)]
189struct GeminiTool {
190    #[serde(rename = "functionDeclarations")]
191    function_declarations: Vec<GeminiFunctionDeclaration>,
192}
193
194#[derive(Debug, Serialize, PartialEq)]
195struct GeminiFunctionDeclaration {
196    name: String,
197    description: String,
198    parameters: GeminiParameterSchema,
199}
200
201#[derive(Debug, Serialize, PartialEq)]
202struct GeminiParameterSchema {
203    #[serde(rename = "type")]
204    schema_type: String, // Typically "object"
205    properties: serde_json::Map<String, Value>,
206    required: Vec<String>,
207}
208
209#[derive(Debug, Deserialize)]
210struct GeminiResponse {
211    #[serde(rename = "candidates")]
212    #[serde(skip_serializing_if = "Option::is_none")]
213    candidates: Option<Vec<GeminiCandidate>>,
214    #[serde(rename = "promptFeedback")]
215    #[serde(skip_serializing_if = "Option::is_none")]
216    prompt_feedback: Option<GeminiPromptFeedback>,
217    #[serde(rename = "usageMetadata")]
218    #[serde(skip_serializing_if = "Option::is_none")]
219    usage_metadata: Option<GeminiUsageMetadata>,
220}
221
222#[derive(Debug, Deserialize)]
223struct GeminiCandidate {
224    content: GeminiContentResponse,
225    #[serde(rename = "finishReason")]
226    #[serde(skip_serializing_if = "Option::is_none")]
227    finish_reason: Option<GeminiFinishReason>,
228    #[serde(rename = "safetyRatings")]
229    #[serde(skip_serializing_if = "Option::is_none")]
230    safety_ratings: Option<Vec<GeminiSafetyRating>>,
231    #[serde(rename = "citationMetadata")]
232    #[serde(skip_serializing_if = "Option::is_none")]
233    citation_metadata: Option<GeminiCitationMetadata>,
234}
235
236#[derive(Debug, Deserialize)]
237#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
238enum GeminiFinishReason {
239    FinishReasonUnspecified,
240    Stop,
241    MaxTokens,
242    Safety,
243    Recitation,
244    Other,
245    #[serde(rename = "TOOL_CODE_ERROR")]
246    ToolCodeError,
247    #[serde(rename = "TOOL_EXECUTION_HALT")]
248    ToolExecutionHalt,
249    MalformedFunctionCall,
250}
251
252#[derive(Debug, Deserialize)]
253struct GeminiPromptFeedback {
254    #[serde(rename = "blockReason")]
255    #[serde(skip_serializing_if = "Option::is_none")]
256    block_reason: Option<GeminiBlockReason>,
257    #[serde(rename = "safetyRatings")]
258    #[serde(skip_serializing_if = "Option::is_none")]
259    safety_ratings: Option<Vec<GeminiSafetyRating>>,
260}
261
262#[derive(Debug, Deserialize)]
263#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
264enum GeminiBlockReason {
265    BlockReasonUnspecified,
266    Safety,
267    Other,
268}
269
270#[derive(Debug, Deserialize)]
271#[allow(dead_code)]
272struct GeminiSafetyRating {
273    category: GeminiHarmCategory,
274    probability: GeminiHarmProbability,
275    #[serde(default)] // Default to false if missing
276    blocked: bool,
277}
278
279#[derive(Debug, Deserialize, Serialize)] // Add Serialize for potential use in SafetySetting
280#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
281#[allow(clippy::enum_variant_names)]
282enum GeminiHarmCategory {
283    HarmCategoryUnspecified,
284    HarmCategoryDerogatory,
285    HarmCategoryToxicity,
286    HarmCategoryViolence,
287    HarmCategorySexual,
288    HarmCategoryMedical,
289    HarmCategoryDangerous,
290    HarmCategoryHarassment,
291    HarmCategoryHateSpeech,
292    HarmCategorySexuallyExplicit,
293    HarmCategoryDangerousContent,
294    HarmCategoryCivicIntegrity,
295}
296
297#[derive(Debug, Deserialize)]
298#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
299enum GeminiHarmProbability {
300    HarmProbabilityUnspecified,
301    Negligible,
302    Low,
303    Medium,
304    High,
305}
306
307#[allow(dead_code)]
308#[derive(Debug, Deserialize)]
309struct GeminiCitationMetadata {
310    #[serde(rename = "citationSources")]
311    #[serde(skip_serializing_if = "Option::is_none")]
312    citation_sources: Option<Vec<GeminiCitationSource>>,
313}
314
315#[allow(dead_code)]
316#[derive(Debug, Deserialize)]
317struct GeminiCitationSource {
318    #[serde(rename = "startIndex")]
319    #[serde(skip_serializing_if = "Option::is_none")]
320    start_index: Option<i32>,
321    #[serde(rename = "endIndex")]
322    #[serde(skip_serializing_if = "Option::is_none")]
323    end_index: Option<i32>,
324    #[serde(skip_serializing_if = "Option::is_none")]
325    uri: Option<String>,
326    #[serde(skip_serializing_if = "Option::is_none")]
327    license: Option<String>,
328}
329
330#[derive(Debug, Deserialize)]
331struct GeminiUsageMetadata {
332    #[serde(rename = "promptTokenCount")]
333    #[serde(skip_serializing_if = "Option::is_none")]
334    prompt_token_count: Option<i32>,
335    #[serde(rename = "candidatesTokenCount")]
336    #[serde(skip_serializing_if = "Option::is_none")]
337    candidates_token_count: Option<i32>,
338    #[serde(rename = "totalTokenCount")]
339    #[serde(skip_serializing_if = "Option::is_none")]
340    total_token_count: Option<i32>,
341}
342
343#[derive(Debug, Serialize, Deserialize)]
344struct GeminiFunctionResponse {
345    name: String,
346    response: GeminiResponseContent,
347}
348
349#[derive(Debug, Serialize, Deserialize)]
350struct GeminiResponseContent {
351    content: Value,
352}
353
354#[derive(Debug, Serialize, Deserialize)]
355struct GeminiExecutableCode {
356    language: String, // e.g., PYTHON
357    code: String,
358}
359
360#[derive(Debug, Deserialize)]
361#[allow(dead_code)]
362struct GeminiContentResponse {
363    role: String,
364    parts: Vec<GeminiResponsePart>,
365}
366
367fn convert_messages(messages: Vec<AppMessage>) -> Vec<GeminiContent> {
368    messages
369        .into_iter()
370        .filter_map(|msg| match &msg.data {
371            crate::app::conversation::MessageData::User { content, .. } => {
372                let parts: Vec<GeminiRequestPart> = content
373                    .iter()
374                    .filter_map(|user_content| match user_content {
375                        UserContent::Text { text } => {
376                            Some(GeminiRequestPart::Text { text: text.clone() })
377                        }
378                        UserContent::CommandExecution {
379                            command,
380                            stdout,
381                            stderr,
382                            exit_code,
383                        } => Some(GeminiRequestPart::Text {
384                            text: UserContent::format_command_execution_as_xml(
385                                command, stdout, stderr, *exit_code,
386                            ),
387                        }),
388                        UserContent::AppCommand { .. } => {
389                            // Don't send app commands to the model - they're for local execution only
390                            None
391                        }
392                    })
393                    .collect();
394
395                // Only include the message if it has content after filtering
396                if parts.is_empty() {
397                    None
398                } else {
399                    Some(GeminiContent {
400                        role: "user".to_string(),
401                        parts,
402                    })
403                }
404            }
405            crate::app::conversation::MessageData::Assistant { content, .. } => {
406                let parts: Vec<GeminiRequestPart> = content
407                    .iter()
408                    .filter_map(|assistant_content| match assistant_content {
409                        AssistantContent::Text { text } => {
410                            Some(GeminiRequestPart::Text { text: text.clone() })
411                        }
412                        AssistantContent::ToolCall { tool_call } => {
413                            Some(GeminiRequestPart::FunctionCall {
414                                function_call: GeminiFunctionCall {
415                                    name: tool_call.name.clone(),
416                                    args: tool_call.parameters.clone(),
417                                },
418                            })
419                        }
420                        AssistantContent::Thought { .. } => {
421                            // Gemini doesn't send thought blocks in requests
422                            None
423                        }
424                    })
425                    .collect();
426
427                // Always include assistant messages (they should always have content)
428                Some(GeminiContent {
429                    role: "model".to_string(),
430                    parts,
431                })
432            }
433            crate::app::conversation::MessageData::Tool {
434                tool_use_id,
435                result,
436                ..
437            } => {
438                // Convert tool result to function response
439                let result_value = match result {
440                    ToolResult::Error(e) => Value::String(format!("Error: {e}")),
441                    _ => {
442                        // For all other variants, try to serialize as JSON
443                        serde_json::to_value(result)
444                            .unwrap_or_else(|_| Value::String(result.llm_format()))
445                    }
446                };
447
448                let parts = vec![GeminiRequestPart::FunctionResponse {
449                    function_response: GeminiFunctionResponse {
450                        name: tool_use_id.clone(), // Use tool_use_id as function name
451                        response: GeminiResponseContent {
452                            content: result_value,
453                        },
454                    },
455                }];
456
457                Some(GeminiContent {
458                    role: "function".to_string(),
459                    parts,
460                })
461            }
462        })
463        .collect()
464}
465
466fn simplify_property_schema(key: &str, tool_name: &str, property_value: &Value) -> Value {
467    if let Some(prop_map_orig) = property_value.as_object() {
468        let mut simplified_prop = prop_map_orig.clone();
469
470        // Remove 'additionalProperties' as Gemini doesn't support it
471        if simplified_prop.remove("additionalProperties").is_some() {
472            debug!(target: "gemini::simplify_property_schema", "Removed 'additionalProperties' from property '{}' in tool '{}'", key, tool_name);
473        }
474
475        // Simplify 'type' field (handle arrays like ["string", "null"])
476        if let Some(type_val) = simplified_prop.get_mut("type") {
477            if let Some(type_array) = type_val.as_array() {
478                if let Some(primary_type) = type_array
479                    .iter()
480                    .find_map(|v| if !v.is_null() { v.as_str() } else { None })
481                {
482                    *type_val = serde_json::Value::String(primary_type.to_string());
483                } else {
484                    warn!(target: "gemini::simplify_property_schema", "Could not determine primary type for property '{}' in tool '{}', defaulting to string.", key, tool_name);
485                    *type_val = serde_json::Value::String("string".to_string());
486                }
487            } else if !type_val.is_string() {
488                warn!(target: "gemini::simplify_property_schema", "Unexpected 'type' format for property '{}' in tool '{}': {:?}. Defaulting to string.", key, tool_name, type_val);
489                *type_val = serde_json::Value::String("string".to_string());
490            }
491            // If it's already a simple string, do nothing.
492        }
493
494        // Fix integer format if necessary
495        if simplified_prop.get("type") == Some(&serde_json::Value::String("integer".to_string())) {
496            if let Some(format_val) = simplified_prop.get_mut("format") {
497                if format_val.as_str() == Some("uint64") {
498                    *format_val = serde_json::Value::String("int64".to_string());
499                    // Optionally remove minimum if Gemini doesn't support it with int64
500                    // simplified_prop.remove("minimum");
501                }
502            }
503        }
504
505        // For string types, Gemini only supports 'enum' and 'date-time' formats
506        if simplified_prop.get("type") == Some(&serde_json::Value::String("string".to_string())) {
507            let should_remove_format = simplified_prop
508                .get("format")
509                .and_then(|f| f.as_str())
510                .map(|format_str| format_str != "enum" && format_str != "date-time")
511                .unwrap_or(false);
512
513            if should_remove_format {
514                if let Some(format_val) = simplified_prop.remove("format") {
515                    if let Some(format_str) = format_val.as_str() {
516                        debug!(target: "gemini::simplify_property_schema", "Removed unsupported format '{}' from string property '{}' in tool '{}'", format_str, key, tool_name);
517                    }
518                }
519            }
520
521            // Also remove other string validation fields that might not be supported
522            if simplified_prop.remove("minLength").is_some() {
523                debug!(target: "gemini::simplify_property_schema", "Removed 'minLength' from string property '{}' in tool '{}'", key, tool_name);
524            }
525            if simplified_prop.remove("maxLength").is_some() {
526                debug!(target: "gemini::simplify_property_schema", "Removed 'maxLength' from string property '{}' in tool '{}'", key, tool_name);
527            }
528            if simplified_prop.remove("pattern").is_some() {
529                debug!(target: "gemini::simplify_property_schema", "Removed 'pattern' from string property '{}' in tool '{}'", key, tool_name);
530            }
531        }
532
533        // Recursively simplify 'items' if this is an array type
534        if simplified_prop.get("type") == Some(&serde_json::Value::String("array".to_string())) {
535            if let Some(items_val) = simplified_prop.get_mut("items") {
536                *items_val =
537                    simplify_property_schema(&format!("{key}.items"), tool_name, items_val);
538            }
539        }
540
541        // Recursively simplify nested 'properties' if this is an object type
542        if simplified_prop.get("type") == Some(&serde_json::Value::String("object".to_string())) {
543            if let Some(Value::Object(props)) = simplified_prop.get_mut("properties") {
544                let simplified_nested_props: serde_json::Map<String, Value> = props
545                    .iter()
546                    .map(|(nested_key, nested_value)| {
547                        (
548                            nested_key.clone(),
549                            simplify_property_schema(
550                                &format!("{key}.{nested_key}"),
551                                tool_name,
552                                nested_value,
553                            ),
554                        )
555                    })
556                    .collect();
557                *props = simplified_nested_props;
558            }
559        }
560
561        serde_json::Value::Object(simplified_prop)
562    } else {
563        warn!(target: "gemini::simplify_property_schema", "Property value for '{}' in tool '{}' is not an object: {:?}. Using original value.", key, tool_name, property_value);
564        property_value.clone() // Return original if not an object
565    }
566}
567
568fn convert_tools(tools: Vec<ToolSchema>) -> Vec<GeminiTool> {
569    let function_declarations = tools
570        .into_iter()
571        .map(|tool| {
572            // Simplify properties schema for Gemini using the helper function
573            let simplified_properties = tool
574                .input_schema
575                .properties
576                .iter()
577                .map(|(key, value)| {
578                    (
579                        key.clone(),
580                        simplify_property_schema(key, &tool.name, value),
581                    )
582                })
583                .collect();
584
585            // Construct the parameters object using the specific struct
586            let parameters = GeminiParameterSchema {
587                schema_type: tool.input_schema.schema_type, // Use schema_type field (usually "object")
588                properties: simplified_properties,          // Use simplified properties
589                required: tool.input_schema.required,       // Use required field
590            };
591
592            GeminiFunctionDeclaration {
593                name: tool.name,
594                description: tool.description,
595                parameters,
596            }
597        })
598        .collect();
599
600    vec![GeminiTool {
601        function_declarations,
602    }]
603}
604
605fn convert_response(response: GeminiResponse) -> Result<CompletionResponse, ApiError> {
606    // Log prompt feedback if present
607    if let Some(feedback) = &response.prompt_feedback {
608        if let Some(reason) = &feedback.block_reason {
609            let details = format!(
610                "Prompt blocked due to {:?}. Safety ratings: {:?}",
611                reason, feedback.safety_ratings
612            );
613            warn!(target: "gemini::convert_response", "{}", details);
614            // Return the specific RequestBlocked error
615            return Err(ApiError::RequestBlocked {
616                provider: "google".to_string(), // Assuming "google" is the provider name
617                details,
618            });
619        }
620    }
621
622    // Check candidates *after* checking for prompt blocking
623    let candidates = match response.candidates {
624        Some(cands) => {
625            if cands.is_empty() {
626                // If it was blocked, the previous check should have caught it.
627                // So, this means no candidates were generated for other reasons.
628                warn!(target: "gemini::convert_response", "No candidates received, and prompt was not blocked.");
629                // Use NoChoices error here
630                return Err(ApiError::NoChoices {
631                    provider: "google".to_string(),
632                });
633            }
634            cands // Return the non-empty vector
635        }
636        None => {
637            warn!(target: "gemini::convert_response", "No candidates field in Gemini response.");
638            // Use NoChoices error here as well
639            return Err(ApiError::NoChoices {
640                provider: "google".to_string(),
641            });
642        }
643    };
644
645    // For simplicity, still taking the first candidate. Multi-candidate handling could be added.
646    // Access candidates safely since we've checked it's not None or empty.
647    let candidate = &candidates[0];
648
649    // Log finish reason and safety ratings if present
650    if let Some(reason) = &candidate.finish_reason {
651        match reason {
652            GeminiFinishReason::Stop => { /* Normal completion */ }
653            GeminiFinishReason::MaxTokens => {
654                warn!(target: "gemini::convert_response", "Response stopped due to MaxTokens limit.");
655            }
656            GeminiFinishReason::Safety => {
657                warn!(target: "gemini::convert_response", "Response stopped due to safety settings. Ratings: {:?}", candidate.safety_ratings);
658                // Consider returning an error or modifying the response based on safety ratings
659            }
660            GeminiFinishReason::Recitation => {
661                warn!(target: "gemini::convert_response", "Response stopped due to potential recitation. Citations: {:?}", candidate.citation_metadata);
662            }
663            GeminiFinishReason::MalformedFunctionCall => {
664                warn!(target: "gemini::convert_response", "Response stopped due to malformed function call.");
665            }
666            _ => {
667                info!(target: "gemini::convert_response", "Response finished with reason: {:?}", reason);
668            }
669        }
670    }
671
672    // Log usage metadata if present
673    if let Some(usage) = &response.usage_metadata {
674        debug!(target: "gemini::convert_response", "Usage - Prompt Tokens: {:?}, Candidates Tokens: {:?}, Total Tokens: {:?}",
675               usage.prompt_token_count, usage.candidates_token_count, usage.total_token_count);
676    }
677
678    let content: Vec<AssistantContent> = candidate
679        .content // GeminiContentResponse
680        .parts   // Vec<GeminiResponsePart> (struct)
681        .iter()
682        .filter_map(|part| { // part is &GeminiResponsePart (struct)
683            // Check if this is a thought part first
684            if part.thought {
685                debug!(target: "gemini::convert_response", "Received thought part: {:?}", part);
686                // For thought parts, extract text content and create a Thought block
687                match &part.data {
688                    GeminiResponsePartData::Text { text } => {
689                        Some(AssistantContent::Thought {
690                            thought: ThoughtContent::Simple {
691                                text: text.clone(),
692                            },
693                        })
694                    }
695                    _ => {
696                        warn!(target: "gemini::convert_response", "Thought part contains non-text data: {:?}", part.data);
697                        None
698                    }
699                }
700            } else {
701                // Regular (non-thought) content processing
702                match &part.data {
703                    GeminiResponsePartData::Text { text } => Some(AssistantContent::Text {
704                        text: text.clone(),
705                    }),
706                    GeminiResponsePartData::InlineData { inline_data } => {
707                        warn!(target: "gemini::convert_response", "Received InlineData part (MIME type: {}). Converting to placeholder text.", inline_data.mime_type);
708                        Some(AssistantContent::Text { text: format!("[Inline Data: {}]", inline_data.mime_type) })
709                    }
710                    GeminiResponsePartData::FunctionCall { function_call } => {
711                        Some(AssistantContent::ToolCall {
712                            tool_call: steer_tools::ToolCall {
713                                id: uuid::Uuid::new_v4().to_string(), // Generate a synthetic ID
714                                name: function_call.name.clone(),
715                                parameters: function_call.args.clone(),
716                            },
717                        })
718                    }
719                    GeminiResponsePartData::FileData { file_data } => {
720                        warn!(target: "gemini::convert_response", "Received FileData part (URI: {}). Converting to placeholder text.", file_data.file_uri);
721                        Some(AssistantContent::Text { text: format!("[File Data: {}]", file_data.file_uri) })
722                    }
723                     GeminiResponsePartData::ExecutableCode { executable_code } => {
724                         info!(target: "gemini::convert_response", "Received ExecutableCode part ({}). Converting to text.",
725                              executable_code.language);
726                         Some(AssistantContent::Text {
727                             text: format!(
728                                 "```{}
729{}
730```",
731                                 executable_code.language.to_lowercase(),
732                                 executable_code.code
733                             ),
734                         })
735                     }
736                }
737            }
738        })
739        .collect();
740
741    Ok(CompletionResponse { content })
742}
743
744#[async_trait]
745impl Provider for GeminiClient {
    /// Provider identifier for this client. The same literal "google" is used as the
    /// `provider` field of the errors built in `convert_response`.
    fn name(&self) -> &'static str {
        "google"
    }
749
750    async fn complete(
751        &self,
752        model: Model,
753        messages: Vec<AppMessage>,
754        system: Option<String>,
755        tools: Option<Vec<ToolSchema>>,
756        token: CancellationToken,
757    ) -> Result<CompletionResponse, ApiError> {
758        let model_name = model.as_ref();
759        let url = format!(
760            "{}/models/{}:generateContent?key={}",
761            GEMINI_API_BASE, model_name, self.api_key
762        );
763
764        let gemini_contents = convert_messages(messages);
765
766        let system_instruction = system.map(|instructions| GeminiSystemInstruction {
767            parts: vec![GeminiRequestPart::Text { text: instructions }],
768        });
769
770        let gemini_tools = tools.map(convert_tools);
771
772        let request = GeminiRequest {
773            contents: gemini_contents,
774            system_instruction,
775            tools: gemini_tools,
776            generation_config: Some(GeminiGenerationConfig {
777                temperature: Some(1.0),
778                top_p: Some(0.95),
779                max_output_tokens: Some(65536),
780                thinking_config: Some(GeminiThinkingConfig {
781                    include_thoughts: Some(true),
782                    thinking_budget: Some(8192),
783                }),
784                ..Default::default()
785            }),
786        };
787
788        let response = tokio::select! {
789            biased;
790            _ = token.cancelled() => {
791                debug!(target: "gemini::complete", "Cancellation token triggered before sending request.");
792                return Err(ApiError::Cancelled{ provider: self.name().to_string()});
793            }
794            res = self.client.post(&url).json(&request).send() => {
795                res.map_err(ApiError::Network)?
796            }
797        };
798        let status = response.status();
799
800        if status != StatusCode::OK {
801            let error_text = response.text().await.map_err(ApiError::Network)?;
802            error!(target: "Gemini API Error Response", "Status: {}, Body: {}", status, error_text);
803            return Err(match status.as_u16() {
804                401 | 403 => ApiError::AuthenticationFailed {
805                    provider: self.name().to_string(),
806                    details: error_text,
807                },
808                429 => ApiError::RateLimited {
809                    provider: self.name().to_string(),
810                    details: error_text,
811                },
812                400 | 404 => {
813                    error!(target: "Gemini API Error Response", "Status: {}, Body: {}, Request: {}", status, error_text, serde_json::to_string_pretty(&request).unwrap_or_else(|_| "Failed to serialize request".to_string()));
814                    ApiError::InvalidRequest {
815                        provider: self.name().to_string(),
816                        details: error_text,
817                    }
818                } // 404 might mean invalid model
819                500..=599 => ApiError::ServerError {
820                    provider: self.name().to_string(),
821                    status_code: status.as_u16(),
822                    details: error_text,
823                },
824                _ => ApiError::Unknown {
825                    provider: self.name().to_string(),
826                    details: error_text,
827                },
828            });
829        }
830
831        let response_text = response.text().await.map_err(ApiError::Network)?;
832
833        match serde_json::from_str::<GeminiResponse>(&response_text) {
834            Ok(gemini_response) => {
835                convert_response(gemini_response).map_err(|e| ApiError::ResponseParsingError {
836                    provider: self.name().to_string(),
837                    details: e.to_string(),
838                })
839            }
840            Err(e) => {
841                error!(target: "Gemini API JSON Parsing Error", "Failed to parse JSON: {}. Response body:\n{}", e, response_text);
842                Err(ApiError::ResponseParsingError {
843                    provider: self.name().to_string(),
844                    details: format!("Status: {status}, Error: {e}, Body: {response_text}"),
845                })
846            }
847        }
848    }
849}
850
#[cfg(test)]
mod tests {
    // Tests for the JSON-schema simplification applied before tools are sent
    // to Gemini, and for the end-to-end `convert_tools` mapping.
    use super::*;
    use serde_json::json;

    #[test]
    fn test_simplify_property_schema_removes_additional_properties() {
        // `additionalProperties` is dropped; type and properties survive.
        let input = json!({
            "type": "object",
            "properties": { "name": { "type": "string" } },
            "additionalProperties": false
        });

        assert_eq!(
            simplify_property_schema("testProp", "testTool", &input),
            json!({
                "type": "object",
                "properties": { "name": { "type": "string" } }
            })
        );
    }

    #[test]
    fn test_simplify_property_schema_removes_unsupported_string_formats() {
        // A `uri` format plus length/pattern constraints all get stripped.
        let input = json!({
            "type": "string",
            "format": "uri",
            "minLength": 1,
            "maxLength": 100,
            "pattern": "^https://"
        });

        assert_eq!(
            simplify_property_schema("urlProp", "testTool", &input),
            json!({ "type": "string" })
        );
    }

    #[test]
    fn test_simplify_property_schema_keeps_supported_string_formats() {
        // `date-time` is a format Gemini accepts, so it is kept verbatim.
        let input = json!({ "type": "string", "format": "date-time" });

        assert_eq!(
            simplify_property_schema("dateProp", "testTool", &input),
            json!({ "type": "string", "format": "date-time" })
        );
    }

    #[test]
    fn test_simplify_property_schema_handles_array_types() {
        // A `["string", "null"]` union collapses to `string`; `email` is dropped.
        let input = json!({ "type": ["string", "null"], "format": "email" });

        assert_eq!(
            simplify_property_schema("emailProp", "testTool", &input),
            json!({ "type": "string" })
        );
    }

    #[test]
    fn test_simplify_property_schema_recursively_handles_array_items() {
        // Simplification descends into `items`.
        let input = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "url": { "type": "string", "format": "uri" }
                },
                "additionalProperties": false
            }
        });

        assert_eq!(
            simplify_property_schema("linksProp", "testTool", &input),
            json!({
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "url": { "type": "string" }
                    }
                }
            })
        );
    }

    #[test]
    fn test_simplify_property_schema_recursively_handles_nested_objects() {
        // Simplification descends through nested `properties` at every level.
        let input = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "type": "object",
                    "properties": {
                        "field": { "type": "string", "format": "hostname" }
                    },
                    "additionalProperties": true
                }
            },
            "additionalProperties": false
        });

        assert_eq!(
            simplify_property_schema("complexProp", "testTool", &input),
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "type": "object",
                        "properties": {
                            "field": { "type": "string" }
                        }
                    }
                }
            })
        );
    }

    #[test]
    fn test_simplify_property_schema_fixes_uint64_format() {
        // Unsigned `uint64` is rewritten to the supported `int64`.
        let input = json!({ "type": "integer", "format": "uint64" });

        assert_eq!(
            simplify_property_schema("idProp", "testTool", &input),
            json!({ "type": "integer", "format": "int64" })
        );
    }

    #[test]
    fn test_convert_tools_integration() {
        use steer_tools::{InputSchema, ToolSchema};

        let input_properties: serde_json::Map<String, serde_json::Value> = [
            (
                "title".to_string(),
                json!({ "type": "string", "minLength": 1 }),
            ),
            (
                "links".to_string(),
                json!({
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "url": { "type": "string", "format": "uri" }
                        },
                        "additionalProperties": false
                    }
                }),
            ),
        ]
        .into_iter()
        .collect();

        let tool = ToolSchema {
            name: "create_issue".to_string(),
            description: "Create an issue".to_string(),
            input_schema: InputSchema {
                schema_type: "object".to_string(),
                properties: input_properties,
                required: vec!["title".to_string()],
            },
        };

        let expected_properties: serde_json::Map<String, serde_json::Value> = [
            ("title".to_string(), json!({ "type": "string" })),
            (
                "links".to_string(),
                json!({
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "url": { "type": "string" }
                        }
                    }
                }),
            ),
        ]
        .into_iter()
        .collect();

        let expected = vec![GeminiTool {
            function_declarations: vec![GeminiFunctionDeclaration {
                name: "create_issue".to_string(),
                description: "Create an issue".to_string(),
                parameters: GeminiParameterSchema {
                    schema_type: "object".to_string(),
                    properties: expected_properties,
                    required: vec!["title".to_string()],
                },
            }],
        }];

        assert_eq!(convert_tools(vec![tool]), expected);
    }
}