Skip to main content

steer_core/api/gemini/
client.rs

1use async_trait::async_trait;
2use futures::StreamExt;
3use reqwest::{Client as HttpClient, StatusCode};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use tokio_util::sync::CancellationToken;
7use tracing::{debug, error, info, warn};
8
9use crate::api::error::{ApiError, SseParseError, StreamError};
10use crate::api::provider::{
11    CompletionResponse, CompletionStream, Provider, StreamChunk, TokenUsage,
12};
13use crate::api::sse::parse_sse_stream;
14use crate::api::util::map_http_status_to_api_error;
15use crate::app::SystemContext;
16use crate::app::conversation::{
17    AssistantContent, ImageSource, Message as AppMessage, ThoughtContent, ThoughtSignature,
18    ToolResult, UserContent,
19};
20use crate::config::model::{ModelId, ModelParameters};
21use steer_tools::ToolSchema;
22
23const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
24
25#[derive(Debug, Deserialize, Serialize, Clone)] // Added Serialize and Clone for potential future use
26struct GeminiBlob {
27    #[serde(rename = "mimeType")]
28    mime_type: String,
29    data: String, // Assuming base64 encoded data
30}
31
32#[derive(Debug, Deserialize, Serialize, Clone)] // Added Serialize and Clone
33struct GeminiFileData {
34    #[serde(rename = "mimeType")]
35    mime_type: String,
36    #[serde(rename = "fileUri")]
37    file_uri: String,
38}
39
40#[expect(dead_code)]
41#[derive(Debug, Deserialize, Serialize, Clone)] // Added Serialize and Clone
42struct GeminiCodeExecutionResult {
43    outcome: String, // e.g., "OK", "ERROR"
44                     // Potentially add output field later if needed
45}
46
47pub struct GeminiClient {
48    api_key: String,
49    client: HttpClient,
50}
51
52impl GeminiClient {
53    pub fn new(api_key: impl Into<String>) -> Self {
54        Self {
55            api_key: api_key.into(),
56            client: HttpClient::new(),
57        }
58    }
59}
60
61#[derive(Debug, Serialize)]
62struct GeminiRequest {
63    contents: Vec<GeminiContent>,
64    #[serde(skip_serializing_if = "Option::is_none")]
65    #[serde(rename = "systemInstruction")]
66    system_instruction: Option<GeminiSystemInstruction>,
67    #[serde(skip_serializing_if = "Option::is_none")]
68    tools: Option<Vec<GeminiTool>>,
69    #[serde(skip_serializing_if = "Option::is_none")]
70    #[serde(rename = "generationConfig")]
71    generation_config: Option<GeminiGenerationConfig>,
72}
73
74#[derive(Debug, Serialize, Default, Clone)]
75struct GeminiGenerationConfig {
76    #[serde(skip_serializing_if = "Option::is_none")]
77    #[serde(rename = "stopSequences")]
78    stop_sequences: Option<Vec<String>>,
79    #[serde(skip_serializing_if = "Option::is_none")]
80    #[serde(rename = "responseMimeType")]
81    response_mime_type: Option<GeminiMimeType>,
82    #[serde(skip_serializing_if = "Option::is_none")]
83    #[serde(rename = "candidateCount")]
84    candidate_count: Option<i32>,
85    #[serde(skip_serializing_if = "Option::is_none")]
86    #[serde(rename = "maxOutputTokens")]
87    max_output_tokens: Option<i32>,
88    #[serde(skip_serializing_if = "Option::is_none")]
89    temperature: Option<f32>,
90    #[serde(skip_serializing_if = "Option::is_none")]
91    #[serde(rename = "topP")]
92    top_p: Option<f32>,
93    #[serde(skip_serializing_if = "Option::is_none")]
94    #[serde(rename = "topK")]
95    top_k: Option<i32>,
96    #[serde(skip_serializing_if = "Option::is_none")]
97    #[serde(rename = "thinkingConfig")]
98    thinking_config: Option<GeminiThinkingConfig>,
99}
100
101#[derive(Debug, Serialize, Default, Clone)]
102struct GeminiThinkingConfig {
103    #[serde(skip_serializing_if = "Option::is_none")]
104    #[serde(rename = "includeThoughts")]
105    include_thoughts: Option<bool>,
106    #[serde(skip_serializing_if = "Option::is_none")]
107    #[serde(rename = "thinkingBudget")]
108    thinking_budget: Option<i32>,
109}
110
111#[expect(dead_code)]
112#[derive(Debug, Serialize, Clone)]
113#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
114enum GeminiMimeType {
115    MimeTypeUnspecified,
116    TextPlain,
117    ApplicationJson,
118}
119
120#[derive(Debug, Serialize)]
121struct GeminiSystemInstruction {
122    parts: Vec<GeminiRequestPart>,
123}
124
125#[derive(Debug, Serialize)]
126struct GeminiContent {
127    role: String,
128    parts: Vec<GeminiRequestPart>,
129}
130
131// Enum for parts used ONLY in requests
132#[derive(Debug, Serialize)]
133#[serde(untagged)]
134enum GeminiRequestPart {
135    Text {
136        text: String,
137    },
138    #[serde(rename = "inlineData")]
139    InlineData {
140        #[serde(rename = "inlineData")]
141        inline_data: GeminiBlob,
142    },
143    #[serde(rename = "fileData")]
144    FileData {
145        #[serde(rename = "fileData")]
146        file_data: GeminiFileData,
147    },
148    #[serde(rename = "functionCall")]
149    FunctionCall {
150        #[serde(rename = "functionCall")]
151        function_call: GeminiFunctionCall, // Used for model turns in history
152        #[serde(rename = "thoughtSignature", skip_serializing_if = "Option::is_none")]
153        thought_signature: Option<String>,
154    },
155    #[serde(rename = "functionResponse")]
156    FunctionResponse {
157        #[serde(rename = "functionResponse")]
158        function_response: GeminiFunctionResponse, // Used for function/tool turns
159    },
160}
161
162// Enum for parts received ONLY in responses
163#[derive(Debug, Deserialize)]
164#[serde(untagged)]
165enum GeminiResponsePartData {
166    Text {
167        text: String,
168    },
169    #[serde(rename = "inlineData")]
170    InlineData {
171        #[serde(rename = "inlineData")]
172        inline_data: GeminiBlob,
173    },
174    #[serde(rename = "functionCall")]
175    FunctionCall {
176        #[serde(rename = "functionCall")]
177        function_call: GeminiFunctionCall,
178    },
179    #[serde(rename = "fileData")]
180    FileData {
181        #[serde(rename = "fileData")]
182        file_data: GeminiFileData,
183    },
184    #[serde(rename = "executableCode")]
185    ExecutableCode {
186        #[serde(rename = "executableCode")]
187        executable_code: GeminiExecutableCode,
188    },
189    // Add other variants back here if needed
190}
191
192// 2. Change GeminiResponsePart to a struct
193#[derive(Debug, Deserialize)]
194struct GeminiResponsePart {
195    #[serde(default)] // Defaults to false if missing
196    thought: bool,
197    #[serde(default, rename = "thoughtSignature", alias = "thought_signature")]
198    thought_signature: Option<String>,
199
200    #[serde(flatten)] // Look for data fields directly in this struct's JSON
201    data: GeminiResponsePartData,
202}
203
204#[derive(Debug, Serialize, Deserialize)]
205struct GeminiFunctionCall {
206    name: String,
207    args: Value,
208}
209
210#[derive(Debug, Serialize, PartialEq)]
211struct GeminiTool {
212    #[serde(rename = "functionDeclarations")]
213    function_declarations: Vec<GeminiFunctionDeclaration>,
214}
215
216#[derive(Debug, Serialize, PartialEq)]
217struct GeminiFunctionDeclaration {
218    name: String,
219    description: String,
220    parameters: GeminiParameterSchema,
221}
222
223#[derive(Debug, Serialize, PartialEq)]
224struct GeminiParameterSchema {
225    #[serde(rename = "type")]
226    schema_type: String, // Typically "object"
227    properties: serde_json::Map<String, Value>,
228    required: Vec<String>,
229}
230
231#[derive(Debug, Deserialize)]
232struct GeminiResponse {
233    #[serde(rename = "candidates")]
234    #[serde(skip_serializing_if = "Option::is_none")]
235    candidates: Option<Vec<GeminiCandidate>>,
236    #[serde(rename = "promptFeedback")]
237    #[serde(skip_serializing_if = "Option::is_none")]
238    prompt_feedback: Option<GeminiPromptFeedback>,
239    #[serde(rename = "usageMetadata")]
240    #[serde(skip_serializing_if = "Option::is_none")]
241    usage_metadata: Option<GeminiUsageMetadata>,
242}
243
244#[derive(Debug, Deserialize)]
245struct GeminiCandidate {
246    content: GeminiContentResponse,
247    #[serde(rename = "finishReason")]
248    #[serde(skip_serializing_if = "Option::is_none")]
249    finish_reason: Option<GeminiFinishReason>,
250    #[serde(rename = "safetyRatings")]
251    #[serde(skip_serializing_if = "Option::is_none")]
252    safety_ratings: Option<Vec<GeminiSafetyRating>>,
253    #[serde(rename = "citationMetadata")]
254    #[serde(skip_serializing_if = "Option::is_none")]
255    citation_metadata: Option<GeminiCitationMetadata>,
256}
257
258#[derive(Debug, Deserialize)]
259#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
260enum GeminiFinishReason {
261    FinishReasonUnspecified,
262    Stop,
263    MaxTokens,
264    Safety,
265    Recitation,
266    Other,
267    #[serde(rename = "TOOL_CODE_ERROR")]
268    ToolCodeError,
269    #[serde(rename = "TOOL_EXECUTION_HALT")]
270    ToolExecutionHalt,
271    MalformedFunctionCall,
272}
273
274#[derive(Debug, Deserialize)]
275struct GeminiPromptFeedback {
276    #[serde(rename = "blockReason")]
277    #[serde(skip_serializing_if = "Option::is_none")]
278    block_reason: Option<GeminiBlockReason>,
279    #[serde(rename = "safetyRatings")]
280    #[serde(skip_serializing_if = "Option::is_none")]
281    safety_ratings: Option<Vec<GeminiSafetyRating>>,
282}
283
284#[derive(Debug, Deserialize)]
285#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
286enum GeminiBlockReason {
287    BlockReasonUnspecified,
288    Safety,
289    Other,
290}
291
292#[derive(Debug, Deserialize)]
293#[expect(dead_code)]
294struct GeminiSafetyRating {
295    category: GeminiHarmCategory,
296    probability: GeminiHarmProbability,
297    #[serde(default)] // Default to false if missing
298    blocked: bool,
299}
300
301#[derive(Debug, Deserialize, Serialize)] // Add Serialize for potential use in SafetySetting
302#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
303#[expect(clippy::enum_variant_names)]
304enum GeminiHarmCategory {
305    HarmCategoryUnspecified,
306    HarmCategoryDerogatory,
307    HarmCategoryToxicity,
308    HarmCategoryViolence,
309    HarmCategorySexual,
310    HarmCategoryMedical,
311    HarmCategoryDangerous,
312    HarmCategoryHarassment,
313    HarmCategoryHateSpeech,
314    HarmCategorySexuallyExplicit,
315    HarmCategoryDangerousContent,
316    HarmCategoryCivicIntegrity,
317}
318
319#[derive(Debug, Deserialize)]
320#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
321enum GeminiHarmProbability {
322    HarmProbabilityUnspecified,
323    Negligible,
324    Low,
325    Medium,
326    High,
327}
328
329#[expect(dead_code)]
330#[derive(Debug, Deserialize)]
331struct GeminiCitationMetadata {
332    #[serde(rename = "citationSources")]
333    #[serde(skip_serializing_if = "Option::is_none")]
334    citation_sources: Option<Vec<GeminiCitationSource>>,
335}
336
337#[expect(dead_code)]
338#[derive(Debug, Deserialize)]
339struct GeminiCitationSource {
340    #[serde(rename = "startIndex")]
341    #[serde(skip_serializing_if = "Option::is_none")]
342    start_index: Option<i32>,
343    #[serde(rename = "endIndex")]
344    #[serde(skip_serializing_if = "Option::is_none")]
345    end_index: Option<i32>,
346    #[serde(skip_serializing_if = "Option::is_none")]
347    uri: Option<String>,
348    #[serde(skip_serializing_if = "Option::is_none")]
349    license: Option<String>,
350}
351
352#[derive(Debug, Deserialize)]
353struct GeminiUsageMetadata {
354    #[serde(rename = "promptTokenCount")]
355    #[serde(skip_serializing_if = "Option::is_none")]
356    prompt: Option<i32>,
357    #[serde(rename = "candidatesTokenCount")]
358    #[serde(skip_serializing_if = "Option::is_none")]
359    candidates: Option<i32>,
360    #[serde(rename = "totalTokenCount")]
361    #[serde(skip_serializing_if = "Option::is_none")]
362    total: Option<i32>,
363}
364
365#[derive(Debug, Serialize, Deserialize)]
366struct GeminiFunctionResponse {
367    name: String,
368    response: GeminiResponseContent,
369}
370
371#[derive(Debug, Serialize, Deserialize)]
372struct GeminiResponseContent {
373    content: Value,
374}
375
376#[derive(Debug, Serialize, Deserialize)]
377struct GeminiExecutableCode {
378    language: String, // e.g., PYTHON
379    code: String,
380}
381
382#[derive(Debug, Deserialize)]
383#[expect(dead_code)]
384struct GeminiContentResponse {
385    role: String,
386    parts: Vec<GeminiResponsePart>,
387}
388
389fn map_gemini_usage(usage: &GeminiUsageMetadata) -> Option<TokenUsage> {
390    let prompt = usage.prompt.map(i64::from);
391    let candidates = usage.candidates.map(i64::from);
392    let total = usage.total.map(i64::from);
393
394    let input_tokens = prompt.or_else(|| match (total, candidates) {
395        (Some(total_tokens), Some(output_tokens)) if total_tokens >= output_tokens => {
396            Some(total_tokens - output_tokens)
397        }
398        _ => None,
399    })?;
400
401    let output_tokens = candidates.or_else(|| match (total, prompt) {
402        (Some(total_tokens), Some(input_tokens)) if total_tokens >= input_tokens => {
403            Some(total_tokens - input_tokens)
404        }
405        _ => None,
406    })?;
407
408    let total_tokens = total.or_else(|| input_tokens.checked_add(output_tokens))?;
409
410    Some(TokenUsage::new(
411        u32::try_from(input_tokens).ok()?,
412        u32::try_from(output_tokens).ok()?,
413        u32::try_from(total_tokens).ok()?,
414    ))
415}
416
417fn user_image_to_request_part(
418    image: &crate::app::conversation::ImageContent,
419) -> Result<GeminiRequestPart, ApiError> {
420    match &image.source {
421        ImageSource::DataUrl { data_url } => {
422            let Some((meta, payload)) = data_url.split_once(',') else {
423                return Err(ApiError::InvalidRequest {
424                    provider: "google".to_string(),
425                    details: "Image data URL is missing ',' separator".to_string(),
426                });
427            };
428
429            let Some(meta_body) = meta.strip_prefix("data:") else {
430                return Err(ApiError::InvalidRequest {
431                    provider: "google".to_string(),
432                    details: "Image data URL must start with 'data:'".to_string(),
433                });
434            };
435
436            if !meta_body
437                .split(';')
438                .any(|segment| segment.eq_ignore_ascii_case("base64"))
439            {
440                return Err(ApiError::InvalidRequest {
441                    provider: "google".to_string(),
442                    details: "Gemini image data URLs must be base64 encoded".to_string(),
443                });
444            }
445
446            let mime_type = meta_body
447                .split(';')
448                .next()
449                .filter(|value| !value.trim().is_empty())
450                .unwrap_or("application/octet-stream")
451                .to_string();
452
453            if !mime_type.starts_with("image/") {
454                return Err(ApiError::UnsupportedFeature {
455                    provider: "google".to_string(),
456                    feature: "image input mime type".to_string(),
457                    details: format!("Unsupported image media type '{}'", mime_type),
458                });
459            }
460
461            Ok(GeminiRequestPart::InlineData {
462                inline_data: GeminiBlob {
463                    mime_type,
464                    data: payload.to_string(),
465                },
466            })
467        }
468        ImageSource::Url { url } => Ok(GeminiRequestPart::FileData {
469            file_data: GeminiFileData {
470                mime_type: image.mime_type.clone(),
471                file_uri: url.clone(),
472            },
473        }),
474        ImageSource::SessionFile { relative_path } => Err(ApiError::UnsupportedFeature {
475            provider: "google".to_string(),
476            feature: "image input source".to_string(),
477            details: format!(
478                "Gemini adapter cannot access session file '{}' directly; use data URLs or URL sources",
479                relative_path
480            ),
481        }),
482    }
483}
484
485fn convert_messages(messages: Vec<AppMessage>) -> Result<Vec<GeminiContent>, ApiError> {
486    let mut converted = Vec::new();
487
488    for msg in messages {
489        match &msg.data {
490            crate::app::conversation::MessageData::User { content, .. } => {
491                let mut parts = Vec::new();
492                for user_content in content {
493                    match user_content {
494                        UserContent::Text { text } => {
495                            parts.push(GeminiRequestPart::Text { text: text.clone() });
496                        }
497                        UserContent::Image { image } => {
498                            parts.push(user_image_to_request_part(image)?);
499                        }
500                        UserContent::CommandExecution {
501                            command,
502                            stdout,
503                            stderr,
504                            exit_code,
505                        } => {
506                            parts.push(GeminiRequestPart::Text {
507                                text: UserContent::format_command_execution_as_xml(
508                                    command, stdout, stderr, *exit_code,
509                                ),
510                            });
511                        }
512                    }
513                }
514
515                if !parts.is_empty() {
516                    converted.push(GeminiContent {
517                        role: "user".to_string(),
518                        parts,
519                    });
520                }
521            }
522            crate::app::conversation::MessageData::Assistant { content, .. } => {
523                let mut parts = Vec::new();
524                for assistant_content in content {
525                    match assistant_content {
526                        AssistantContent::Text { text } => {
527                            parts.push(GeminiRequestPart::Text { text: text.clone() });
528                        }
529                        AssistantContent::Image { image } => {
530                            match user_image_to_request_part(image) {
531                                Ok(part) => parts.push(part),
532                                Err(err) => {
533                                    debug!(
534                                        target: "gemini::convert_messages",
535                                        "Skipping unsupported assistant image block: {}",
536                                        err
537                                    );
538                                }
539                            }
540                        }
541                        AssistantContent::ToolCall {
542                            tool_call,
543                            thought_signature,
544                        } => parts.push(GeminiRequestPart::FunctionCall {
545                            function_call: GeminiFunctionCall {
546                                name: tool_call.name.clone(),
547                                args: tool_call.parameters.clone(),
548                            },
549                            thought_signature: thought_signature
550                                .as_ref()
551                                .map(|signature| signature.as_str().to_string()),
552                        }),
553                        AssistantContent::Thought { .. } => {
554                            // Gemini doesn't send thought blocks in requests
555                        }
556                    }
557                }
558
559                converted.push(GeminiContent {
560                    role: "model".to_string(),
561                    parts,
562                });
563            }
564            crate::app::conversation::MessageData::Tool {
565                tool_use_id,
566                result,
567                ..
568            } => {
569                let result_value = match result {
570                    ToolResult::Error(e) => Value::String(format!("Error: {e}")),
571                    _ => serde_json::to_value(result)
572                        .unwrap_or_else(|_| Value::String(result.llm_format())),
573                };
574
575                let parts = vec![GeminiRequestPart::FunctionResponse {
576                    function_response: GeminiFunctionResponse {
577                        name: tool_use_id.clone(),
578                        response: GeminiResponseContent {
579                            content: result_value,
580                        },
581                    },
582                }];
583
584                converted.push(GeminiContent {
585                    role: "function".to_string(),
586                    parts,
587                });
588            }
589        }
590    }
591
592    Ok(converted)
593}
594
595fn resolve_ref<'a>(root: &'a Value, schema: &'a Value) -> Option<&'a Value> {
596    let reference = schema.get("$ref").and_then(|v| v.as_str())?;
597    let path = reference.strip_prefix("#/")?;
598    let mut current = root;
599    for segment in path.split('/') {
600        current = current.get(segment)?;
601    }
602    Some(current)
603}
604
605fn infer_type_from_enum(values: &[Value]) -> Option<String> {
606    let mut has_string = false;
607    let mut has_number = false;
608    let mut has_bool = false;
609    let mut has_object = false;
610    let mut has_array = false;
611
612    for value in values {
613        match value {
614            Value::String(_) => has_string = true,
615            Value::Number(_) => has_number = true,
616            Value::Bool(_) => has_bool = true,
617            Value::Object(_) => has_object = true,
618            Value::Array(_) => has_array = true,
619            Value::Null => {}
620        }
621    }
622
623    let kind_count = u8::from(has_string)
624        + u8::from(has_number)
625        + u8::from(has_bool)
626        + u8::from(has_object)
627        + u8::from(has_array);
628
629    if kind_count != 1 {
630        return None;
631    }
632
633    if has_string {
634        Some("string".to_string())
635    } else if has_number {
636        Some("number".to_string())
637    } else if has_bool {
638        Some("boolean".to_string())
639    } else if has_object {
640        Some("object".to_string())
641    } else if has_array {
642        Some("array".to_string())
643    } else {
644        None
645    }
646}
647
648fn normalize_type(value: &Value) -> Value {
649    if let Some(type_str) = value.as_str() {
650        return Value::String(type_str.to_string());
651    }
652
653    if let Some(type_array) = value.as_array()
654        && let Some(primary_type) = type_array
655            .iter()
656            .find_map(|v| if v.is_null() { None } else { v.as_str() })
657    {
658        return Value::String(primary_type.to_string());
659    }
660
661    Value::String("string".to_string())
662}
663
664fn extract_enum_values(value: &Value) -> Vec<Value> {
665    let Some(obj) = value.as_object() else {
666        return Vec::new();
667    };
668
669    if let Some(enum_values) = obj.get("enum").and_then(|v| v.as_array()) {
670        return enum_values
671            .iter()
672            .filter(|v| !v.is_null())
673            .cloned()
674            .collect();
675    }
676
677    if let Some(const_value) = obj.get("const") {
678        if const_value.is_null() {
679            return Vec::new();
680        }
681        return vec![const_value.clone()];
682    }
683
684    Vec::new()
685}
686
687fn merge_property(properties: &mut serde_json::Map<String, Value>, key: &str, value: &Value) {
688    match properties.get_mut(key) {
689        None => {
690            properties.insert(key.to_string(), value.clone());
691        }
692        Some(existing) => {
693            if existing == value {
694                return;
695            }
696
697            let existing_values = extract_enum_values(existing);
698            let incoming_values = extract_enum_values(value);
699            if incoming_values.is_empty() && existing_values.is_empty() {
700                return;
701            }
702
703            let mut combined = existing_values;
704            for item in incoming_values {
705                if !combined.contains(&item) {
706                    combined.push(item);
707                }
708            }
709
710            if combined.is_empty() {
711                return;
712            }
713
714            if let Some(obj) = existing.as_object_mut() {
715                obj.remove("const");
716                obj.insert("enum".to_string(), Value::Array(combined.clone()));
717                if !obj.contains_key("type")
718                    && let Some(inferred) = infer_type_from_enum(&combined)
719                {
720                    obj.insert("type".to_string(), Value::String(inferred));
721                }
722            }
723        }
724    }
725}
726
727fn merge_union_schemas(root: &Value, variants: &[Value]) -> Value {
728    let mut merged_props = serde_json::Map::new();
729    let mut required_intersection: Option<std::collections::BTreeSet<String>> = None;
730    let mut enum_values: Vec<Value> = Vec::new();
731    let mut type_candidates: Vec<String> = Vec::new();
732
733    for variant in variants {
734        let sanitized = sanitize_for_gemini(root, variant);
735
736        if let Some(schema_type) = sanitized.get("type").and_then(|v| v.as_str()) {
737            type_candidates.push(schema_type.to_string());
738        }
739
740        if let Some(props) = sanitized.get("properties").and_then(|v| v.as_object()) {
741            for (key, value) in props {
742                merge_property(&mut merged_props, key, value);
743            }
744        }
745
746        if let Some(req) = sanitized.get("required").and_then(|v| v.as_array()) {
747            let req_set: std::collections::BTreeSet<String> = req
748                .iter()
749                .filter_map(|item| item.as_str().map(|s| s.to_string()))
750                .collect();
751
752            required_intersection = match required_intersection.take() {
753                None => Some(req_set),
754                Some(existing) => Some(
755                    existing
756                        .intersection(&req_set)
757                        .cloned()
758                        .collect::<std::collections::BTreeSet<String>>(),
759                ),
760            };
761        }
762
763        if let Some(values) = sanitized.get("enum").and_then(|v| v.as_array()) {
764            for value in values {
765                if value.is_null() {
766                    continue;
767                }
768                if !enum_values.contains(value) {
769                    enum_values.push(value.clone());
770                }
771            }
772        }
773    }
774
775    let schema_type = if !merged_props.is_empty() {
776        "object".to_string()
777    } else if let Some(inferred) = infer_type_from_enum(&enum_values) {
778        inferred
779    } else if let Some(first) = type_candidates.first() {
780        first.clone()
781    } else {
782        "string".to_string()
783    };
784
785    let mut merged = serde_json::Map::new();
786    merged.insert("type".to_string(), Value::String(schema_type));
787
788    if !merged_props.is_empty() {
789        merged.insert("properties".to_string(), Value::Object(merged_props));
790    }
791
792    if let Some(required_set) = required_intersection
793        && !required_set.is_empty()
794    {
795        merged.insert(
796            "required".to_string(),
797            Value::Array(
798                required_set
799                    .into_iter()
800                    .map(Value::String)
801                    .collect::<Vec<_>>(),
802            ),
803        );
804    }
805
806    if !enum_values.is_empty() {
807        merged.insert("enum".to_string(), Value::Array(enum_values));
808    }
809
810    Value::Object(merged)
811}
812
813fn sanitize_for_gemini(root: &Value, schema: &Value) -> Value {
814    if let Some(resolved) = resolve_ref(root, schema) {
815        return sanitize_for_gemini(root, resolved);
816    }
817
818    let Some(obj) = schema.as_object() else {
819        return schema.clone();
820    };
821
822    if let Some(union) = obj
823        .get("oneOf")
824        .or_else(|| obj.get("anyOf"))
825        .or_else(|| obj.get("allOf"))
826        .and_then(|v| v.as_array())
827    {
828        return merge_union_schemas(root, union);
829    }
830
831    let mut out = serde_json::Map::new();
832    for (key, value) in obj {
833        match key.as_str() {
834            "$ref"
835            | "$defs"
836            | "oneOf"
837            | "anyOf"
838            | "allOf"
839            | "const"
840            | "additionalProperties"
841            | "default"
842            | "examples"
843            | "title"
844            | "pattern"
845            | "minLength"
846            | "maxLength"
847            | "minimum"
848            | "maximum"
849            | "minItems"
850            | "maxItems"
851            | "uniqueItems"
852            | "deprecated" => {}
853            "type" => {
854                out.insert("type".to_string(), normalize_type(value));
855            }
856            "properties" => {
857                if let Some(props) = value.as_object() {
858                    let mut sanitized_props = serde_json::Map::new();
859                    for (prop_key, prop_value) in props {
860                        sanitized_props
861                            .insert(prop_key.clone(), sanitize_for_gemini(root, prop_value));
862                    }
863                    out.insert("properties".to_string(), Value::Object(sanitized_props));
864                }
865            }
866            "items" => {
867                out.insert("items".to_string(), sanitize_for_gemini(root, value));
868            }
869            "enum" => {
870                let values = value
871                    .as_array()
872                    .map(|items| {
873                        items
874                            .iter()
875                            .filter(|v| !v.is_null())
876                            .cloned()
877                            .collect::<Vec<_>>()
878                    })
879                    .unwrap_or_default();
880                out.insert("enum".to_string(), Value::Array(values));
881            }
882            _ => {
883                out.insert(key.clone(), sanitize_for_gemini(root, value));
884            }
885        }
886    }
887
888    if let Some(const_value) = obj.get("const")
889        && !const_value.is_null()
890    {
891        out.insert("enum".to_string(), Value::Array(vec![const_value.clone()]));
892        if !out.contains_key("type")
893            && let Some(inferred) = infer_type_from_enum(std::slice::from_ref(const_value))
894        {
895            out.insert("type".to_string(), Value::String(inferred));
896        }
897    }
898
899    if out.get("type") == Some(&Value::String("object".to_string()))
900        && !out.contains_key("properties")
901    {
902        out.insert(
903            "properties".to_string(),
904            Value::Object(serde_json::Map::new()),
905        );
906    }
907
908    if !out.contains_key("type") {
909        if out.contains_key("properties") {
910            out.insert("type".to_string(), Value::String("object".to_string()));
911        } else if let Some(enum_values) = out.get("enum").and_then(|v| v.as_array())
912            && let Some(inferred) = infer_type_from_enum(enum_values)
913        {
914            out.insert("type".to_string(), Value::String(inferred));
915        }
916    }
917
918    Value::Object(out)
919}
920
921fn simplify_property_schema(key: &str, tool_name: &str, property_value: &Value) -> Value {
922    if let Some(prop_map_orig) = property_value.as_object() {
923        let mut simplified_prop = prop_map_orig.clone();
924
925        // Remove 'additionalProperties' as Gemini doesn't support it
926        if simplified_prop.remove("additionalProperties").is_some() {
927            debug!(target: "gemini::simplify_property_schema", "Removed 'additionalProperties' from property '{}' in tool '{}'", key, tool_name);
928        }
929
930        // Simplify 'type' field (handle arrays like ["string", "null"])
931        if let Some(type_val) = simplified_prop.get_mut("type") {
932            if let Some(type_array) = type_val.as_array() {
933                if let Some(primary_type) = type_array
934                    .iter()
935                    .find_map(|v| if v.is_null() { None } else { v.as_str() })
936                {
937                    *type_val = serde_json::Value::String(primary_type.to_string());
938                } else {
939                    warn!(target: "gemini::simplify_property_schema", "Could not determine primary type for property '{}' in tool '{}', defaulting to string.", key, tool_name);
940                    *type_val = serde_json::Value::String("string".to_string());
941                }
942            } else if !type_val.is_string() {
943                warn!(target: "gemini::simplify_property_schema", "Unexpected 'type' format for property '{}' in tool '{}': {:?}. Defaulting to string.", key, tool_name, type_val);
944                *type_val = serde_json::Value::String("string".to_string());
945            }
946            // If it's already a simple string, do nothing.
947        }
948
949        // Fix integer format if necessary
950        if simplified_prop.get("type") == Some(&serde_json::Value::String("integer".to_string()))
951            && let Some(format_val) = simplified_prop.get_mut("format")
952            && format_val.as_str() == Some("uint64")
953        {
954            *format_val = serde_json::Value::String("int64".to_string());
955            // Optionally remove minimum if Gemini doesn't support it with int64
956            // simplified_prop.remove("minimum");
957        }
958
959        // For string types, Gemini only supports 'enum' and 'date-time' formats
960        if simplified_prop.get("type") == Some(&serde_json::Value::String("string".to_string())) {
961            let should_remove_format = simplified_prop
962                .get("format")
963                .and_then(|f| f.as_str())
964                .is_some_and(|format_str| format_str != "enum" && format_str != "date-time");
965
966            if should_remove_format
967                && let Some(format_val) = simplified_prop.remove("format")
968                && let Some(format_str) = format_val.as_str()
969            {
970                debug!(target: "gemini::simplify_property_schema", "Removed unsupported format '{}' from string property '{}' in tool '{}'", format_str, key, tool_name);
971            }
972
973            // Also remove other string validation fields that might not be supported
974            if simplified_prop.remove("minLength").is_some() {
975                debug!(target: "gemini::simplify_property_schema", "Removed 'minLength' from string property '{}' in tool '{}'", key, tool_name);
976            }
977            if simplified_prop.remove("maxLength").is_some() {
978                debug!(target: "gemini::simplify_property_schema", "Removed 'maxLength' from string property '{}' in tool '{}'", key, tool_name);
979            }
980            if simplified_prop.remove("pattern").is_some() {
981                debug!(target: "gemini::simplify_property_schema", "Removed 'pattern' from string property '{}' in tool '{}'", key, tool_name);
982            }
983        }
984
985        // Recursively simplify 'items' if this is an array type
986        if simplified_prop.get("type") == Some(&serde_json::Value::String("array".to_string()))
987            && let Some(items_val) = simplified_prop.get_mut("items")
988        {
989            *items_val = simplify_property_schema(&format!("{key}.items"), tool_name, items_val);
990        }
991
992        // Recursively simplify nested 'properties' if this is an object type
993        if simplified_prop.get("type") == Some(&serde_json::Value::String("object".to_string()))
994            && let Some(Value::Object(props)) = simplified_prop.get_mut("properties")
995        {
996            let simplified_nested_props: serde_json::Map<String, Value> = props
997                .iter()
998                .map(|(nested_key, nested_value)| {
999                    (
1000                        nested_key.clone(),
1001                        simplify_property_schema(
1002                            &format!("{key}.{nested_key}"),
1003                            tool_name,
1004                            nested_value,
1005                        ),
1006                    )
1007                })
1008                .collect();
1009            *props = simplified_nested_props;
1010        }
1011
1012        serde_json::Value::Object(simplified_prop)
1013    } else {
1014        warn!(target: "gemini::simplify_property_schema", "Property value for '{}' in tool '{}' is not an object: {:?}. Using original value.", key, tool_name, property_value);
1015        property_value.clone() // Return original if not an object
1016    }
1017}
1018
1019fn convert_tools(tools: Vec<ToolSchema>) -> Vec<GeminiTool> {
1020    let function_declarations = tools
1021        .into_iter()
1022        .map(|tool| {
1023            let root_schema = tool.input_schema.as_value();
1024            let summary = tool.input_schema.summary();
1025            let schema_type = if summary.schema_type.is_empty() {
1026                "object".to_string()
1027            } else {
1028                summary.schema_type
1029            };
1030
1031            // Simplify properties schema for Gemini using the helper function
1032            let simplified_properties = summary
1033                .properties
1034                .iter()
1035                .map(|(key, value)| {
1036                    let sanitized = sanitize_for_gemini(root_schema, value);
1037                    (
1038                        key.clone(),
1039                        simplify_property_schema(key, &tool.name, &sanitized),
1040                    )
1041                })
1042                .collect();
1043
1044            // Construct the parameters object using the specific struct
1045            let parameters = GeminiParameterSchema {
1046                schema_type,                       // Use schema_type field (usually "object")
1047                properties: simplified_properties, // Use simplified properties
1048                required: summary.required,        // Use required field
1049            };
1050
1051            GeminiFunctionDeclaration {
1052                name: tool.name,
1053                description: tool.description,
1054                parameters,
1055            }
1056        })
1057        .collect();
1058
1059    vec![GeminiTool {
1060        function_declarations,
1061    }]
1062}
1063
1064fn convert_response(response: GeminiResponse) -> Result<CompletionResponse, ApiError> {
1065    // Log prompt feedback if present
1066    if let Some(feedback) = &response.prompt_feedback
1067        && let Some(reason) = &feedback.block_reason
1068    {
1069        let details = format!(
1070            "Prompt blocked due to {:?}. Safety ratings: {:?}",
1071            reason, feedback.safety_ratings
1072        );
1073        warn!(target: "gemini::convert_response", "{}", details);
1074        // Return the specific RequestBlocked error
1075        return Err(ApiError::RequestBlocked {
1076            provider: "google".to_string(), // Assuming "google" is the provider name
1077            details,
1078        });
1079    }
1080
1081    // Check candidates *after* checking for prompt blocking
1082    let candidates = if let Some(cands) = response.candidates {
1083        if cands.is_empty() {
1084            // If it was blocked, the previous check should have caught it.
1085            // So, this means no candidates were generated for other reasons.
1086            warn!(target: "gemini::convert_response", "No candidates received, and prompt was not blocked.");
1087            // Use NoChoices error here
1088            return Err(ApiError::NoChoices {
1089                provider: "google".to_string(),
1090            });
1091        }
1092        cands // Return the non-empty vector
1093    } else {
1094        warn!(target: "gemini::convert_response", "No candidates field in Gemini response.");
1095        // Use NoChoices error here as well
1096        return Err(ApiError::NoChoices {
1097            provider: "google".to_string(),
1098        });
1099    };
1100
1101    // For simplicity, still taking the first candidate. Multi-candidate handling could be added.
1102    // Access candidates safely since we've checked it's not None or empty.
1103    let candidate = &candidates[0];
1104
1105    // Log finish reason and safety ratings if present
1106    if let Some(reason) = &candidate.finish_reason {
1107        match reason {
1108            GeminiFinishReason::Stop => { /* Normal completion */ }
1109            GeminiFinishReason::MaxTokens => {
1110                warn!(target: "gemini::convert_response", "Response stopped due to MaxTokens limit.");
1111            }
1112            GeminiFinishReason::Safety => {
1113                warn!(target: "gemini::convert_response", "Response stopped due to safety settings. Ratings: {:?}", candidate.safety_ratings);
1114                // Consider returning an error or modifying the response based on safety ratings
1115            }
1116            GeminiFinishReason::Recitation => {
1117                warn!(target: "gemini::convert_response", "Response stopped due to potential recitation. Citations: {:?}", candidate.citation_metadata);
1118            }
1119            GeminiFinishReason::MalformedFunctionCall => {
1120                warn!(target: "gemini::convert_response", "Response stopped due to malformed function call.");
1121            }
1122            _ => {
1123                info!(target: "gemini::convert_response", "Response finished with reason: {:?}", reason);
1124            }
1125        }
1126    }
1127
1128    // Log usage metadata if present
1129    if let Some(usage) = &response.usage_metadata {
1130        debug!(target: "gemini::convert_response", "Usage - Prompt Tokens: {:?}, Candidates Tokens: {:?}, Total Tokens: {:?}",
1131               usage.prompt, usage.candidates, usage.total);
1132    }
1133    let usage = response.usage_metadata.as_ref().and_then(map_gemini_usage);
1134
1135    let content: Vec<AssistantContent> = candidate
1136        .content // GeminiContentResponse
1137        .parts   // Vec<GeminiResponsePart> (struct)
1138        .iter()
1139        .filter_map(|part| { // part is &GeminiResponsePart (struct)
1140            // Check if this is a thought part first
1141            if part.thought {
1142                debug!(target: "gemini::convert_response", "Received thought part: {:?}", part);
1143                // For thought parts, extract text content and create a Thought block
1144                if let GeminiResponsePartData::Text { text } = &part.data {
1145                    Some(AssistantContent::Thought {
1146                        thought: ThoughtContent::Simple {
1147                            text: text.clone(),
1148                        },
1149                    })
1150                } else {
1151                    warn!(target: "gemini::convert_response", "Thought part contains non-text data: {:?}", part.data);
1152                    None
1153                }
1154            } else {
1155                // Regular (non-thought) content processing
1156                match &part.data {
1157                    GeminiResponsePartData::Text { text } => Some(AssistantContent::Text {
1158                        text: text.clone(),
1159                    }),
1160                    GeminiResponsePartData::InlineData { inline_data } => {
1161                        warn!(target: "gemini::convert_response", "Received InlineData part (MIME type: {}). Converting to placeholder text.", inline_data.mime_type);
1162                        Some(AssistantContent::Text { text: format!("[Inline Data: {}]", inline_data.mime_type) })
1163                    }
1164                    GeminiResponsePartData::FunctionCall { function_call } => {
1165                        Some(AssistantContent::ToolCall {
1166                            tool_call: steer_tools::ToolCall {
1167                                id: uuid::Uuid::new_v4().to_string(), // Generate a synthetic ID
1168                                name: function_call.name.clone(),
1169                                parameters: function_call.args.clone(),
1170                            },
1171                            thought_signature: part
1172                                .thought_signature
1173                                .clone()
1174                                .map(ThoughtSignature::new),
1175                        })
1176                    }
1177                    GeminiResponsePartData::FileData { file_data } => {
1178                        warn!(target: "gemini::convert_response", "Received FileData part (URI: {}). Converting to placeholder text.", file_data.file_uri);
1179                        Some(AssistantContent::Text { text: format!("[File Data: {}]", file_data.file_uri) })
1180                    }
1181                     GeminiResponsePartData::ExecutableCode { executable_code } => {
1182                         info!(target: "gemini::convert_response", "Received ExecutableCode part ({}). Converting to text.",
1183                              executable_code.language);
1184                         Some(AssistantContent::Text {
1185                             text: format!(
1186                                 "```{}
1187{}
1188```",
1189                                 executable_code.language.to_lowercase(),
1190                                 executable_code.code
1191                             ),
1192                         })
1193                     }
1194                }
1195            }
1196        })
1197        .collect();
1198
1199    Ok(CompletionResponse { content, usage })
1200}
1201
1202#[async_trait]
1203impl Provider for GeminiClient {
1204    fn name(&self) -> &'static str {
1205        "google"
1206    }
1207
1208    async fn complete(
1209        &self,
1210        model_id: &ModelId,
1211        messages: Vec<AppMessage>,
1212        system: Option<SystemContext>,
1213        tools: Option<Vec<ToolSchema>>,
1214        _call_options: Option<ModelParameters>,
1215        token: CancellationToken,
1216    ) -> Result<CompletionResponse, ApiError> {
1217        let model_name = &model_id.id; // Use the model ID string
1218        let url = format!(
1219            "{}/models/{}:generateContent?key={}",
1220            GEMINI_API_BASE, model_name, self.api_key
1221        );
1222
1223        let gemini_contents = convert_messages(messages)?;
1224
1225        let system_instruction = system
1226            .and_then(|context| context.render())
1227            .map(|instructions| GeminiSystemInstruction {
1228                parts: vec![GeminiRequestPart::Text { text: instructions }],
1229            });
1230
1231        let gemini_tools = tools.map(convert_tools);
1232
1233        // Derive generation config from call options, respecting model/catalog settings
1234        let (temperature, top_p, max_output_tokens) = {
1235            let opts = _call_options.as_ref();
1236            (
1237                opts.and_then(|o| o.temperature).or(Some(1.0)),
1238                opts.and_then(|o| o.top_p).or(Some(0.95)),
1239                opts.and_then(|o| o.max_output_tokens)
1240                    .map(|v| v as i32)
1241                    .or(Some(65536)),
1242            )
1243        };
1244        let thinking_config = _call_options
1245            .as_ref()
1246            .and_then(|o| o.thinking_config)
1247            .and_then(|tc| {
1248                if tc.enabled {
1249                    Some(GeminiThinkingConfig {
1250                        include_thoughts: tc.include_thoughts,
1251                        thinking_budget: tc.budget_tokens.map(|v| v as i32),
1252                    })
1253                } else {
1254                    None
1255                }
1256            });
1257
1258        let request = GeminiRequest {
1259            contents: gemini_contents,
1260            system_instruction,
1261            tools: gemini_tools,
1262            generation_config: Some(GeminiGenerationConfig {
1263                max_output_tokens,
1264                temperature,
1265                top_p,
1266                thinking_config,
1267                ..Default::default()
1268            }),
1269        };
1270
1271        let response = tokio::select! {
1272            biased;
1273            () = token.cancelled() => {
1274                debug!(target: "gemini::complete", "Cancellation token triggered before sending request.");
1275                return Err(ApiError::Cancelled{ provider: self.name().to_string()});
1276            }
1277            res = self.client.post(&url).json(&request).send() => {
1278                res.map_err(ApiError::Network)?
1279            }
1280        };
1281        let status = response.status();
1282
1283        if status != StatusCode::OK {
1284            let error_text = response.text().await.map_err(ApiError::Network)?;
1285            error!(target: "Gemini API Error Response", "Status: {}, Body: {}", status, error_text);
1286            return Err(map_http_status_to_api_error(
1287                self.name(),
1288                status.as_u16(),
1289                error_text,
1290            ));
1291        }
1292
1293        let response_text = response.text().await.map_err(ApiError::Network)?;
1294
1295        match serde_json::from_str::<GeminiResponse>(&response_text) {
1296            Ok(gemini_response) => {
1297                convert_response(gemini_response).map_err(|e| ApiError::ResponseParsingError {
1298                    provider: self.name().to_string(),
1299                    details: e.to_string(),
1300                })
1301            }
1302            Err(e) => {
1303                error!(target: "Gemini API JSON Parsing Error", "Failed to parse JSON: {}. Response body:\n{}", e, response_text);
1304                Err(ApiError::ResponseParsingError {
1305                    provider: self.name().to_string(),
1306                    details: format!("Status: {status}, Error: {e}, Body: {response_text}"),
1307                })
1308            }
1309        }
1310    }
1311
1312    async fn stream_complete(
1313        &self,
1314        model_id: &ModelId,
1315        messages: Vec<AppMessage>,
1316        system: Option<SystemContext>,
1317        tools: Option<Vec<ToolSchema>>,
1318        _call_options: Option<ModelParameters>,
1319        token: CancellationToken,
1320    ) -> Result<CompletionStream, ApiError> {
1321        let model_name = &model_id.id;
1322        let url = format!(
1323            "{}/models/{}:streamGenerateContent?alt=sse&key={}",
1324            GEMINI_API_BASE, model_name, self.api_key
1325        );
1326
1327        let gemini_contents = convert_messages(messages)?;
1328
1329        let system_instruction = system
1330            .and_then(|context| context.render())
1331            .map(|instructions| GeminiSystemInstruction {
1332                parts: vec![GeminiRequestPart::Text { text: instructions }],
1333            });
1334
1335        let gemini_tools = tools.map(convert_tools);
1336
1337        let (temperature, top_p, max_output_tokens) = {
1338            let opts = _call_options.as_ref();
1339            (
1340                opts.and_then(|o| o.temperature).or(Some(1.0)),
1341                opts.and_then(|o| o.top_p).or(Some(0.95)),
1342                opts.and_then(|o| o.max_output_tokens)
1343                    .map(|v| v as i32)
1344                    .or(Some(65536)),
1345            )
1346        };
1347        let thinking_config = _call_options
1348            .as_ref()
1349            .and_then(|o| o.thinking_config)
1350            .and_then(|tc| {
1351                if tc.enabled {
1352                    Some(GeminiThinkingConfig {
1353                        include_thoughts: tc.include_thoughts,
1354                        thinking_budget: tc.budget_tokens.map(|v| v as i32),
1355                    })
1356                } else {
1357                    None
1358                }
1359            });
1360
1361        let request = GeminiRequest {
1362            contents: gemini_contents,
1363            system_instruction,
1364            tools: gemini_tools,
1365            generation_config: Some(GeminiGenerationConfig {
1366                max_output_tokens,
1367                temperature,
1368                top_p,
1369                thinking_config,
1370                ..Default::default()
1371            }),
1372        };
1373
1374        let response = tokio::select! {
1375            biased;
1376            () = token.cancelled() => {
1377                return Err(ApiError::Cancelled{ provider: self.name().to_string()});
1378            }
1379            res = self.client.post(&url).json(&request).send() => {
1380                res.map_err(ApiError::Network)?
1381            }
1382        };
1383
1384        let status = response.status();
1385        if status != StatusCode::OK {
1386            let error_text = response.text().await.map_err(ApiError::Network)?;
1387            error!(target: "gemini::stream", "API error - Status: {}, Body: {}", status, error_text);
1388            return Err(map_http_status_to_api_error(
1389                self.name(),
1390                status.as_u16(),
1391                error_text,
1392            ));
1393        }
1394
1395        let byte_stream = response.bytes_stream();
1396        let sse_stream = parse_sse_stream(byte_stream);
1397
1398        Ok(Box::pin(Self::convert_gemini_stream(sse_stream, token)))
1399    }
1400}
1401
1402impl GeminiClient {
1403    fn convert_gemini_stream(
1404        mut sse_stream: impl futures::Stream<Item = Result<crate::api::sse::SseEvent, SseParseError>>
1405        + Unpin
1406        + Send
1407        + 'static,
1408        token: CancellationToken,
1409    ) -> impl futures::Stream<Item = StreamChunk> + Send + 'static {
1410        async_stream::stream! {
1411            let mut content: Vec<AssistantContent> = Vec::new();
1412            let mut latest_usage: Option<TokenUsage> = None;
1413            loop {
1414                if token.is_cancelled() {
1415                    yield StreamChunk::Error(StreamError::Cancelled);
1416                    break;
1417                }
1418
1419                let event_result = tokio::select! {
1420                    biased;
1421                    () = token.cancelled() => {
1422                        yield StreamChunk::Error(StreamError::Cancelled);
1423                        break;
1424                    }
1425                    event = sse_stream.next() => event
1426                };
1427
1428                let Some(event_result) = event_result else {
1429                    let content = std::mem::take(&mut content);
1430                    yield StreamChunk::MessageComplete(CompletionResponse {
1431                        content,
1432                        usage: latest_usage,
1433                    });
1434                    break;
1435                };
1436
1437                let event = match event_result {
1438                    Ok(e) => e,
1439                    Err(e) => {
1440                        yield StreamChunk::Error(StreamError::SseParse(e));
1441                        break;
1442                    }
1443                };
1444
1445                let chunk: GeminiResponse = match serde_json::from_str(&event.data) {
1446                    Ok(c) => c,
1447                    Err(e) => {
1448                        debug!(target: "gemini::stream", "Failed to parse chunk: {} data: {}", e, event.data);
1449                        continue;
1450                    }
1451                };
1452
1453                if let Some(usage) = chunk.usage_metadata.as_ref()
1454                    && let Some(mapped) = map_gemini_usage(usage) {
1455                        latest_usage = Some(mapped);
1456                    }
1457
1458                if let Some(candidates) = chunk.candidates {
1459                    for candidate in candidates {
1460                        for part in candidate.content.parts {
1461                            let GeminiResponsePart {
1462                                thought,
1463                                thought_signature,
1464                                data,
1465                            } = part;
1466                            let thought_signature =
1467                                thought_signature.map(ThoughtSignature::new);
1468
1469                            if thought {
1470                                if let GeminiResponsePartData::Text { text } = data {
1471                                    match content.last_mut() {
1472                                        Some(AssistantContent::Thought {
1473                                            thought: ThoughtContent::Simple { text: buf },
1474                                        }) => buf.push_str(&text),
1475                                        _ => {
1476                                            content.push(AssistantContent::Thought {
1477                                                thought: ThoughtContent::Simple { text: text.clone() },
1478                                            });
1479                                        }
1480                                    }
1481                                    yield StreamChunk::ThinkingDelta(text);
1482                                }
1483                            } else {
1484                                match data {
1485                                    GeminiResponsePartData::Text { text } => {
1486                                        match content.last_mut() {
1487                                            Some(AssistantContent::Text { text: buf }) => buf.push_str(&text),
1488                                            _ => {
1489                                                content.push(AssistantContent::Text { text: text.clone() });
1490                                            }
1491                                        }
1492                                        yield StreamChunk::TextDelta(text);
1493                                    }
1494                                    GeminiResponsePartData::FunctionCall { function_call } => {
1495                                        let id = uuid::Uuid::new_v4().to_string();
1496                                        content.push(AssistantContent::ToolCall {
1497                                            tool_call: steer_tools::ToolCall {
1498                                                id: id.clone(),
1499                                                name: function_call.name.clone(),
1500                                                parameters: function_call.args.clone(),
1501                                            },
1502                                            thought_signature,
1503                                        });
1504                                        yield StreamChunk::ToolUseStart {
1505                                            id: id.clone(),
1506                                            name: function_call.name,
1507                                        };
1508                                        yield StreamChunk::ToolUseInputDelta {
1509                                            id,
1510                                            delta: function_call.args.to_string(),
1511                                        };
1512                                    }
1513                                    _ => {}
1514                                }
1515                            }
1516                        }
1517                    }
1518                }
1519            }
1520        }
1521    }
1522}
1523
1524#[cfg(test)]
1525mod tests {
1526    use super::*;
1527    use crate::app::conversation::{ImageContent, ImageSource, Message, MessageData, UserContent};
1528    use serde_json::json;
1529
1530    #[test]
1531    fn test_simplify_property_schema_removes_additional_properties() {
1532        let property_value = json!({
1533            "type": "object",
1534            "properties": {
1535                "name": {"type": "string"}
1536            },
1537            "additionalProperties": false
1538        });
1539
1540        let expected = json!({
1541            "type": "object",
1542            "properties": {
1543                "name": {"type": "string"}
1544            }
1545        });
1546
1547        let result = simplify_property_schema("testProp", "testTool", &property_value);
1548        assert_eq!(result, expected);
1549    }
1550
1551    #[test]
1552    fn test_simplify_property_schema_removes_unsupported_string_formats() {
1553        let property_value = json!({
1554            "type": "string",
1555            "format": "uri",
1556            "minLength": 1,
1557            "maxLength": 100,
1558            "pattern": "^https://"
1559        });
1560
1561        let expected = json!({
1562            "type": "string"
1563        });
1564
1565        let result = simplify_property_schema("urlProp", "testTool", &property_value);
1566        assert_eq!(result, expected);
1567    }
1568
1569    #[test]
1570    fn test_simplify_property_schema_keeps_supported_string_formats() {
1571        let property_value = json!({
1572            "type": "string",
1573            "format": "date-time"
1574        });
1575
1576        let expected = json!({
1577            "type": "string",
1578            "format": "date-time"
1579        });
1580
1581        let result = simplify_property_schema("dateProp", "testTool", &property_value);
1582        assert_eq!(result, expected);
1583    }
1584
1585    #[test]
1586    fn test_simplify_property_schema_handles_array_types() {
1587        let property_value = json!({
1588            "type": ["string", "null"],
1589            "format": "email"
1590        });
1591
1592        let expected = json!({
1593            "type": "string"
1594        });
1595
1596        let result = simplify_property_schema("emailProp", "testTool", &property_value);
1597        assert_eq!(result, expected);
1598    }
1599
1600    #[test]
1601    fn test_simplify_property_schema_recursively_handles_array_items() {
1602        let property_value = json!({
1603            "type": "array",
1604            "items": {
1605                "type": "object",
1606                "properties": {
1607                    "url": {
1608                        "type": "string",
1609                        "format": "uri"
1610                    }
1611                },
1612                "additionalProperties": false
1613            }
1614        });
1615
1616        let expected = json!({
1617            "type": "array",
1618            "items": {
1619                "type": "object",
1620                "properties": {
1621                    "url": {
1622                        "type": "string"
1623                    }
1624                }
1625            }
1626        });
1627
1628        let result = simplify_property_schema("linksProp", "testTool", &property_value);
1629        assert_eq!(result, expected);
1630    }
1631
1632    #[test]
1633    fn test_simplify_property_schema_recursively_handles_nested_objects() {
1634        let property_value = json!({
1635            "type": "object",
1636            "properties": {
1637                "nested": {
1638                    "type": "object",
1639                    "properties": {
1640                        "field": {
1641                            "type": "string",
1642                            "format": "hostname"
1643                        }
1644                    },
1645                    "additionalProperties": true
1646                }
1647            },
1648            "additionalProperties": false
1649        });
1650
1651        let expected = json!({
1652            "type": "object",
1653            "properties": {
1654                "nested": {
1655                    "type": "object",
1656                    "properties": {
1657                        "field": {
1658                            "type": "string"
1659                        }
1660                    }
1661                }
1662            }
1663        });
1664
1665        let result = simplify_property_schema("complexProp", "testTool", &property_value);
1666        assert_eq!(result, expected);
1667    }
1668
1669    #[test]
1670    fn test_simplify_property_schema_fixes_uint64_format() {
1671        let property_value = json!({
1672            "type": "integer",
1673            "format": "uint64"
1674        });
1675
1676        let expected = json!({
1677            "type": "integer",
1678            "format": "int64"
1679        });
1680
1681        let result = simplify_property_schema("idProp", "testTool", &property_value);
1682        assert_eq!(result, expected);
1683    }
1684
1685    #[test]
1686    fn test_convert_tools_integration() {
1687        use steer_tools::{InputSchema, ToolSchema};
1688
1689        let tool = ToolSchema {
1690            name: "create_issue".to_string(),
1691            display_name: "Create Issue".to_string(),
1692            description: "Create an issue".to_string(),
1693            input_schema: InputSchema::object(
1694                {
1695                    let mut props = serde_json::Map::new();
1696                    props.insert(
1697                        "title".to_string(),
1698                        json!({
1699                            "type": "string",
1700                            "minLength": 1
1701                        }),
1702                    );
1703                    props.insert(
1704                        "links".to_string(),
1705                        json!({
1706                            "type": "array",
1707                            "items": {
1708                                "type": "object",
1709                                "properties": {
1710                                    "url": {
1711                                        "type": "string",
1712                                        "format": "uri"
1713                                    }
1714                                },
1715                                "additionalProperties": false
1716                            }
1717                        }),
1718                    );
1719                    props
1720                },
1721                vec!["title".to_string()],
1722            ),
1723        };
1724
1725        let expected_tools = vec![GeminiTool {
1726            function_declarations: vec![GeminiFunctionDeclaration {
1727                name: "create_issue".to_string(),
1728                description: "Create an issue".to_string(),
1729                parameters: GeminiParameterSchema {
1730                    schema_type: "object".to_string(),
1731                    properties: {
1732                        let mut props = serde_json::Map::new();
1733                        props.insert(
1734                            "title".to_string(),
1735                            json!({
1736                                "type": "string"
1737                            }),
1738                        );
1739                        props.insert(
1740                            "links".to_string(),
1741                            json!({
1742                                "type": "array",
1743                                "items": {
1744                                    "type": "object",
1745                                    "properties": {
1746                                        "url": {
1747                                            "type": "string"
1748                                        }
1749                                    }
1750                                }
1751                            }),
1752                        );
1753                        props
1754                    },
1755                    required: vec!["title".to_string()],
1756                },
1757            }],
1758        }];
1759
1760        let result = convert_tools(vec![tool]);
1761        assert_eq!(result, expected_tools);
1762    }
1763
1764    #[test]
1765    fn test_convert_messages_includes_data_url_image_as_inline_data() {
1766        let messages = vec![Message {
1767            data: MessageData::User {
1768                content: vec![UserContent::Image {
1769                    image: ImageContent {
1770                        mime_type: "image/png".to_string(),
1771                        source: ImageSource::DataUrl {
1772                            data_url: "data:image/png;base64,Zm9v".to_string(),
1773                        },
1774                        width: None,
1775                        height: None,
1776                        bytes: None,
1777                        sha256: None,
1778                    },
1779                }],
1780            },
1781            timestamp: 1,
1782            id: "msg-image".to_string(),
1783            parent_message_id: None,
1784        }];
1785
1786        let converted = convert_messages(messages).expect("convert messages");
1787        assert_eq!(converted.len(), 1);
1788        let first = &converted[0];
1789        assert_eq!(first.role, "user");
1790        assert_eq!(first.parts.len(), 1);
1791
1792        let json = serde_json::to_value(&first.parts[0]).expect("serialize part");
1793        assert_eq!(json["inlineData"]["mimeType"], "image/png");
1794        assert_eq!(json["inlineData"]["data"], "Zm9v");
1795    }
1796
1797    #[test]
1798    fn test_convert_messages_rejects_session_file_image_source() {
1799        let messages = vec![Message {
1800            data: MessageData::User {
1801                content: vec![UserContent::Image {
1802                    image: ImageContent {
1803                        mime_type: "image/png".to_string(),
1804                        source: ImageSource::SessionFile {
1805                            relative_path: "session-1/image.png".to_string(),
1806                        },
1807                        width: None,
1808                        height: None,
1809                        bytes: None,
1810                        sha256: None,
1811                    },
1812                }],
1813            },
1814            timestamp: 1,
1815            id: "msg-image".to_string(),
1816            parent_message_id: None,
1817        }];
1818
1819        let err = convert_messages(messages).expect_err("expected unsupported feature");
1820        match err {
1821            ApiError::UnsupportedFeature {
1822                provider,
1823                feature,
1824                details,
1825            } => {
1826                assert_eq!(provider, "google");
1827                assert_eq!(feature, "image input source");
1828                assert!(details.contains("session file"));
1829            }
1830            other => panic!("Expected UnsupportedFeature, got {other:?}"),
1831        }
1832    }
1833
1834    #[test]
1835    fn test_map_gemini_usage_derives_missing_counter() {
1836        let usage = GeminiUsageMetadata {
1837            prompt: None,
1838            candidates: Some(7),
1839            total: Some(19),
1840        };
1841
1842        assert_eq!(map_gemini_usage(&usage), Some(TokenUsage::new(12, 7, 19)));
1843    }
1844
1845    #[test]
1846    fn test_map_gemini_usage_returns_none_when_counters_not_derivable() {
1847        let usage = GeminiUsageMetadata {
1848            prompt: Some(9),
1849            candidates: None,
1850            total: None,
1851        };
1852
1853        assert_eq!(map_gemini_usage(&usage), None);
1854    }
1855
1856    #[test]
1857    fn test_convert_response_maps_usage_metadata() {
1858        let response = GeminiResponse {
1859            candidates: Some(vec![GeminiCandidate {
1860                content: GeminiContentResponse {
1861                    role: "model".to_string(),
1862                    parts: vec![GeminiResponsePart {
1863                        thought: false,
1864                        thought_signature: None,
1865                        data: GeminiResponsePartData::Text {
1866                            text: "Hello".to_string(),
1867                        },
1868                    }],
1869                },
1870                finish_reason: Some(GeminiFinishReason::Stop),
1871                safety_ratings: None,
1872                citation_metadata: None,
1873            }]),
1874            prompt_feedback: None,
1875            usage_metadata: Some(GeminiUsageMetadata {
1876                prompt: Some(11),
1877                candidates: Some(5),
1878                total: Some(16),
1879            }),
1880        };
1881
1882        let converted = convert_response(response).expect("response should convert");
1883
1884        assert_eq!(converted.usage, Some(TokenUsage::new(11, 5, 16)));
1885        assert!(matches!(
1886            converted.content.first(),
1887            Some(AssistantContent::Text { text }) if text == "Hello"
1888        ));
1889    }
1890
1891    #[tokio::test]
1892    async fn test_convert_gemini_stream_text_deltas() {
1893        use crate::api::provider::StreamChunk;
1894        use crate::api::sse::SseEvent;
1895        use futures::StreamExt;
1896        use futures::stream;
1897        use std::pin::pin;
1898        use tokio_util::sync::CancellationToken;
1899
1900        let events = vec![
1901            Ok(SseEvent {
1902                event_type: None,
1903                data: r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#
1904                    .to_string(),
1905                id: None,
1906            }),
1907            Ok(SseEvent {
1908                event_type: None,
1909                data:
1910                    r#"{"candidates":[{"content":{"role":"model","parts":[{"text":" world"}]}}]}"#
1911                        .to_string(),
1912                id: None,
1913            }),
1914        ];
1915
1916        let sse_stream = stream::iter(events);
1917        let token = CancellationToken::new();
1918        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));
1919
1920        let first_delta = stream.next().await.unwrap();
1921        assert!(matches!(first_delta, StreamChunk::TextDelta(ref t) if t == "Hello"));
1922
1923        let second_delta = stream.next().await.unwrap();
1924        assert!(matches!(second_delta, StreamChunk::TextDelta(ref t) if t == " world"));
1925
1926        let complete = stream.next().await.unwrap();
1927        assert!(matches!(complete, StreamChunk::MessageComplete(_)));
1928    }
1929
1930    #[tokio::test]
1931    async fn test_convert_gemini_stream_captures_final_usage() {
1932        use crate::api::provider::StreamChunk;
1933        use crate::api::sse::SseEvent;
1934        use futures::StreamExt;
1935        use futures::stream;
1936        use std::pin::pin;
1937        use tokio_util::sync::CancellationToken;
1938
1939        let events = vec![
1940            Ok(SseEvent {
1941                event_type: None,
1942                data: r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#
1943                    .to_string(),
1944                id: None,
1945            }),
1946            Ok(SseEvent {
1947                event_type: None,
1948                data: r#"{"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":4,"totalTokenCount":14}}"#
1949                    .to_string(),
1950                id: None,
1951            }),
1952        ];
1953
1954        let sse_stream = stream::iter(events);
1955        let token = CancellationToken::new();
1956        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));
1957
1958        let first_delta = stream.next().await.unwrap();
1959        assert!(matches!(first_delta, StreamChunk::TextDelta(ref t) if t == "Hello"));
1960
1961        let complete = stream.next().await.unwrap();
1962        if let StreamChunk::MessageComplete(response) = complete {
1963            assert_eq!(response.usage, Some(TokenUsage::new(10, 4, 14)));
1964            assert!(matches!(
1965                response.content.first(),
1966                Some(AssistantContent::Text { text }) if text == "Hello"
1967            ));
1968        } else {
1969            panic!("Expected MessageComplete");
1970        }
1971    }
1972
1973    #[tokio::test]
1974    async fn test_convert_gemini_stream_with_thinking() {
1975        use crate::api::provider::StreamChunk;
1976        use crate::api::sse::SseEvent;
1977        use futures::StreamExt;
1978        use futures::stream;
1979        use std::pin::pin;
1980        use tokio_util::sync::CancellationToken;
1981
1982        let events = vec![
1983            Ok(SseEvent {
1984                event_type: None,
1985                data: r#"{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"Let me think..."}]}}]}"#.to_string(),
1986                id: None,
1987            }),
1988            Ok(SseEvent {
1989                event_type: None,
1990                data: r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"The answer"}]}}]}"#.to_string(),
1991                id: None,
1992            }),
1993        ];
1994
1995        let sse_stream = stream::iter(events);
1996        let token = CancellationToken::new();
1997        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));
1998
1999        let thinking_delta = stream.next().await.unwrap();
2000        assert!(
2001            matches!(thinking_delta, StreamChunk::ThinkingDelta(ref t) if t == "Let me think...")
2002        );
2003
2004        let text_delta = stream.next().await.unwrap();
2005        assert!(matches!(text_delta, StreamChunk::TextDelta(ref t) if t == "The answer"));
2006
2007        let complete = stream.next().await.unwrap();
2008        if let StreamChunk::MessageComplete(response) = complete {
2009            assert_eq!(response.content.len(), 2);
2010            assert!(matches!(
2011                &response.content[0],
2012                AssistantContent::Thought { .. }
2013            ));
2014            assert!(matches!(
2015                &response.content[1],
2016                AssistantContent::Text { .. }
2017            ));
2018        } else {
2019            panic!("Expected MessageComplete");
2020        }
2021    }
2022
2023    #[tokio::test]
2024    async fn test_convert_gemini_stream_with_function_call() {
2025        use crate::api::provider::StreamChunk;
2026        use crate::api::sse::SseEvent;
2027        use futures::StreamExt;
2028        use futures::stream;
2029        use std::pin::pin;
2030        use tokio_util::sync::CancellationToken;
2031
2032        let events = vec![
2033            Ok(SseEvent {
2034                event_type: None,
2035                data: r#"{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"get_weather","args":{"city":"NYC"}},"thoughtSignature":"sig_123"}]}}]}"#.to_string(),
2036                id: None,
2037            }),
2038        ];
2039
2040        let sse_stream = stream::iter(events);
2041        let token = CancellationToken::new();
2042        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));
2043
2044        let tool_start = stream.next().await.unwrap();
2045        assert!(
2046            matches!(tool_start, StreamChunk::ToolUseStart { ref name, .. } if name == "get_weather")
2047        );
2048
2049        let tool_input = stream.next().await.unwrap();
2050        assert!(matches!(tool_input, StreamChunk::ToolUseInputDelta { .. }));
2051
2052        let complete = stream.next().await.unwrap();
2053        if let StreamChunk::MessageComplete(response) = complete {
2054            assert_eq!(response.content.len(), 1);
2055            if let AssistantContent::ToolCall {
2056                tool_call,
2057                thought_signature,
2058            } = &response.content[0]
2059            {
2060                assert_eq!(tool_call.name, "get_weather");
2061                assert_eq!(
2062                    thought_signature.as_ref().map(|sig| sig.as_str()),
2063                    Some("sig_123")
2064                );
2065            } else {
2066                panic!("Expected ToolCall");
2067            }
2068        } else {
2069            panic!("Expected MessageComplete");
2070        }
2071    }
2072
2073    #[tokio::test]
2074    async fn test_convert_gemini_stream_cancellation() {
2075        use crate::api::error::StreamError;
2076        use crate::api::provider::StreamChunk;
2077        use crate::api::sse::SseEvent;
2078        use futures::StreamExt;
2079        use futures::stream;
2080        use std::pin::pin;
2081        use tokio_util::sync::CancellationToken;
2082
2083        let events = vec![Ok(SseEvent {
2084            event_type: None,
2085            data: r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#
2086                .to_string(),
2087            id: None,
2088        })];
2089
2090        let sse_stream = stream::iter(events);
2091        let token = CancellationToken::new();
2092        token.cancel();
2093
2094        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));
2095
2096        let cancelled = stream.next().await.unwrap();
2097        assert!(matches!(
2098            cancelled,
2099            StreamChunk::Error(StreamError::Cancelled)
2100        ));
2101    }
2102
2103    #[tokio::test]
2104    #[ignore = "Requires GOOGLE_API_KEY environment variable"]
2105    async fn test_stream_complete_real_api() {
2106        use crate::api::Provider;
2107        use crate::api::provider::StreamChunk;
2108        use crate::app::conversation::{Message, MessageData, UserContent};
2109        use futures::StreamExt;
2110        use tokio_util::sync::CancellationToken;
2111
2112        dotenvy::dotenv().ok();
2113        let api_key = std::env::var("GOOGLE_API_KEY").expect("GOOGLE_API_KEY must be set");
2114        let client = GeminiClient::new(api_key);
2115
2116        let message = Message {
2117            data: MessageData::User {
2118                content: vec![UserContent::Text {
2119                    text: "Say exactly: Hello".to_string(),
2120                }],
2121            },
2122            timestamp: chrono::Utc::now().timestamp_millis() as u64,
2123            id: "test-msg".to_string(),
2124            parent_message_id: None,
2125        };
2126
2127        let model_id = ModelId::new(
2128            crate::config::provider::google(),
2129            "gemini-2.5-flash-preview-04-17",
2130        );
2131        let token = CancellationToken::new();
2132
2133        let mut stream = client
2134            .stream_complete(&model_id, vec![message], None, None, None, token)
2135            .await
2136            .expect("stream_complete should succeed");
2137
2138        let mut got_text_delta = false;
2139        let mut got_message_complete = false;
2140        let mut accumulated_text = String::new();
2141
2142        while let Some(chunk) = stream.next().await {
2143            match chunk {
2144                StreamChunk::TextDelta(text) => {
2145                    got_text_delta = true;
2146                    accumulated_text.push_str(&text);
2147                }
2148                StreamChunk::MessageComplete(response) => {
2149                    got_message_complete = true;
2150                    assert!(!response.content.is_empty());
2151                }
2152                StreamChunk::Error(e) => panic!("Unexpected error: {e:?}"),
2153                _ => {}
2154            }
2155        }
2156
2157        assert!(got_text_delta, "Should receive at least one TextDelta");
2158        assert!(
2159            got_message_complete,
2160            "Should receive MessageComplete at the end"
2161        );
2162        assert!(
2163            accumulated_text.to_lowercase().contains("hello"),
2164            "Response should contain 'hello', got: {accumulated_text}"
2165        );
2166    }
2167}