Skip to main content

steer_core/api/gemini/
client.rs

1use async_trait::async_trait;
2use futures::StreamExt;
3use reqwest::{Client as HttpClient, StatusCode};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use tokio_util::sync::CancellationToken;
7use tracing::{debug, error, info, warn};
8
9use crate::api::error::{ApiError, SseParseError, StreamError};
10use crate::api::provider::{
11    CompletionResponse, CompletionStream, Provider, StreamChunk, TokenUsage,
12};
13use crate::api::sse::parse_sse_stream;
14use crate::app::SystemContext;
15use crate::app::conversation::{
16    AssistantContent, ImageSource, Message as AppMessage, ThoughtContent, ThoughtSignature,
17    ToolResult, UserContent,
18};
19use crate::config::model::{ModelId, ModelParameters};
20use steer_tools::ToolSchema;
21
22const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
23
24#[derive(Debug, Deserialize, Serialize, Clone)] // Added Serialize and Clone for potential future use
25struct GeminiBlob {
26    #[serde(rename = "mimeType")]
27    mime_type: String,
28    data: String, // Assuming base64 encoded data
29}
30
31#[derive(Debug, Deserialize, Serialize, Clone)] // Added Serialize and Clone
32struct GeminiFileData {
33    #[serde(rename = "mimeType")]
34    mime_type: String,
35    #[serde(rename = "fileUri")]
36    file_uri: String,
37}
38
39#[expect(dead_code)]
40#[derive(Debug, Deserialize, Serialize, Clone)] // Added Serialize and Clone
41struct GeminiCodeExecutionResult {
42    outcome: String, // e.g., "OK", "ERROR"
43                     // Potentially add output field later if needed
44}
45
46pub struct GeminiClient {
47    api_key: String,
48    client: HttpClient,
49}
50
51impl GeminiClient {
52    pub fn new(api_key: impl Into<String>) -> Self {
53        Self {
54            api_key: api_key.into(),
55            client: HttpClient::new(),
56        }
57    }
58}
59
60#[derive(Debug, Serialize)]
61struct GeminiRequest {
62    contents: Vec<GeminiContent>,
63    #[serde(skip_serializing_if = "Option::is_none")]
64    #[serde(rename = "systemInstruction")]
65    system_instruction: Option<GeminiSystemInstruction>,
66    #[serde(skip_serializing_if = "Option::is_none")]
67    tools: Option<Vec<GeminiTool>>,
68    #[serde(skip_serializing_if = "Option::is_none")]
69    #[serde(rename = "generationConfig")]
70    generation_config: Option<GeminiGenerationConfig>,
71}
72
73#[derive(Debug, Serialize, Default, Clone)]
74struct GeminiGenerationConfig {
75    #[serde(skip_serializing_if = "Option::is_none")]
76    #[serde(rename = "stopSequences")]
77    stop_sequences: Option<Vec<String>>,
78    #[serde(skip_serializing_if = "Option::is_none")]
79    #[serde(rename = "responseMimeType")]
80    response_mime_type: Option<GeminiMimeType>,
81    #[serde(skip_serializing_if = "Option::is_none")]
82    #[serde(rename = "candidateCount")]
83    candidate_count: Option<i32>,
84    #[serde(skip_serializing_if = "Option::is_none")]
85    #[serde(rename = "maxOutputTokens")]
86    max_output_tokens: Option<i32>,
87    #[serde(skip_serializing_if = "Option::is_none")]
88    temperature: Option<f32>,
89    #[serde(skip_serializing_if = "Option::is_none")]
90    #[serde(rename = "topP")]
91    top_p: Option<f32>,
92    #[serde(skip_serializing_if = "Option::is_none")]
93    #[serde(rename = "topK")]
94    top_k: Option<i32>,
95    #[serde(skip_serializing_if = "Option::is_none")]
96    #[serde(rename = "thinkingConfig")]
97    thinking_config: Option<GeminiThinkingConfig>,
98}
99
100#[derive(Debug, Serialize, Default, Clone)]
101struct GeminiThinkingConfig {
102    #[serde(skip_serializing_if = "Option::is_none")]
103    #[serde(rename = "includeThoughts")]
104    include_thoughts: Option<bool>,
105    #[serde(skip_serializing_if = "Option::is_none")]
106    #[serde(rename = "thinkingBudget")]
107    thinking_budget: Option<i32>,
108}
109
110#[expect(dead_code)]
111#[derive(Debug, Serialize, Clone)]
112#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
113enum GeminiMimeType {
114    MimeTypeUnspecified,
115    TextPlain,
116    ApplicationJson,
117}
118
119#[derive(Debug, Serialize)]
120struct GeminiSystemInstruction {
121    parts: Vec<GeminiRequestPart>,
122}
123
124#[derive(Debug, Serialize)]
125struct GeminiContent {
126    role: String,
127    parts: Vec<GeminiRequestPart>,
128}
129
130// Enum for parts used ONLY in requests
131#[derive(Debug, Serialize)]
132#[serde(untagged)]
133enum GeminiRequestPart {
134    Text {
135        text: String,
136    },
137    #[serde(rename = "inlineData")]
138    InlineData {
139        #[serde(rename = "inlineData")]
140        inline_data: GeminiBlob,
141    },
142    #[serde(rename = "fileData")]
143    FileData {
144        #[serde(rename = "fileData")]
145        file_data: GeminiFileData,
146    },
147    #[serde(rename = "functionCall")]
148    FunctionCall {
149        #[serde(rename = "functionCall")]
150        function_call: GeminiFunctionCall, // Used for model turns in history
151        #[serde(rename = "thoughtSignature", skip_serializing_if = "Option::is_none")]
152        thought_signature: Option<String>,
153    },
154    #[serde(rename = "functionResponse")]
155    FunctionResponse {
156        #[serde(rename = "functionResponse")]
157        function_response: GeminiFunctionResponse, // Used for function/tool turns
158    },
159}
160
161// Enum for parts received ONLY in responses
162#[derive(Debug, Deserialize)]
163#[serde(untagged)]
164enum GeminiResponsePartData {
165    Text {
166        text: String,
167    },
168    #[serde(rename = "inlineData")]
169    InlineData {
170        #[serde(rename = "inlineData")]
171        inline_data: GeminiBlob,
172    },
173    #[serde(rename = "functionCall")]
174    FunctionCall {
175        #[serde(rename = "functionCall")]
176        function_call: GeminiFunctionCall,
177    },
178    #[serde(rename = "fileData")]
179    FileData {
180        #[serde(rename = "fileData")]
181        file_data: GeminiFileData,
182    },
183    #[serde(rename = "executableCode")]
184    ExecutableCode {
185        #[serde(rename = "executableCode")]
186        executable_code: GeminiExecutableCode,
187    },
188    // Add other variants back here if needed
189}
190
191// 2. Change GeminiResponsePart to a struct
192#[derive(Debug, Deserialize)]
193struct GeminiResponsePart {
194    #[serde(default)] // Defaults to false if missing
195    thought: bool,
196    #[serde(default, rename = "thoughtSignature", alias = "thought_signature")]
197    thought_signature: Option<String>,
198
199    #[serde(flatten)] // Look for data fields directly in this struct's JSON
200    data: GeminiResponsePartData,
201}
202
203#[derive(Debug, Serialize, Deserialize)]
204struct GeminiFunctionCall {
205    name: String,
206    args: Value,
207}
208
209#[derive(Debug, Serialize, PartialEq)]
210struct GeminiTool {
211    #[serde(rename = "functionDeclarations")]
212    function_declarations: Vec<GeminiFunctionDeclaration>,
213}
214
215#[derive(Debug, Serialize, PartialEq)]
216struct GeminiFunctionDeclaration {
217    name: String,
218    description: String,
219    parameters: GeminiParameterSchema,
220}
221
222#[derive(Debug, Serialize, PartialEq)]
223struct GeminiParameterSchema {
224    #[serde(rename = "type")]
225    schema_type: String, // Typically "object"
226    properties: serde_json::Map<String, Value>,
227    required: Vec<String>,
228}
229
230#[derive(Debug, Deserialize)]
231struct GeminiResponse {
232    #[serde(rename = "candidates")]
233    #[serde(skip_serializing_if = "Option::is_none")]
234    candidates: Option<Vec<GeminiCandidate>>,
235    #[serde(rename = "promptFeedback")]
236    #[serde(skip_serializing_if = "Option::is_none")]
237    prompt_feedback: Option<GeminiPromptFeedback>,
238    #[serde(rename = "usageMetadata")]
239    #[serde(skip_serializing_if = "Option::is_none")]
240    usage_metadata: Option<GeminiUsageMetadata>,
241}
242
243#[derive(Debug, Deserialize)]
244struct GeminiCandidate {
245    content: GeminiContentResponse,
246    #[serde(rename = "finishReason")]
247    #[serde(skip_serializing_if = "Option::is_none")]
248    finish_reason: Option<GeminiFinishReason>,
249    #[serde(rename = "safetyRatings")]
250    #[serde(skip_serializing_if = "Option::is_none")]
251    safety_ratings: Option<Vec<GeminiSafetyRating>>,
252    #[serde(rename = "citationMetadata")]
253    #[serde(skip_serializing_if = "Option::is_none")]
254    citation_metadata: Option<GeminiCitationMetadata>,
255}
256
257#[derive(Debug, Deserialize)]
258#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
259enum GeminiFinishReason {
260    FinishReasonUnspecified,
261    Stop,
262    MaxTokens,
263    Safety,
264    Recitation,
265    Other,
266    #[serde(rename = "TOOL_CODE_ERROR")]
267    ToolCodeError,
268    #[serde(rename = "TOOL_EXECUTION_HALT")]
269    ToolExecutionHalt,
270    MalformedFunctionCall,
271}
272
273#[derive(Debug, Deserialize)]
274struct GeminiPromptFeedback {
275    #[serde(rename = "blockReason")]
276    #[serde(skip_serializing_if = "Option::is_none")]
277    block_reason: Option<GeminiBlockReason>,
278    #[serde(rename = "safetyRatings")]
279    #[serde(skip_serializing_if = "Option::is_none")]
280    safety_ratings: Option<Vec<GeminiSafetyRating>>,
281}
282
283#[derive(Debug, Deserialize)]
284#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
285enum GeminiBlockReason {
286    BlockReasonUnspecified,
287    Safety,
288    Other,
289}
290
291#[derive(Debug, Deserialize)]
292#[expect(dead_code)]
293struct GeminiSafetyRating {
294    category: GeminiHarmCategory,
295    probability: GeminiHarmProbability,
296    #[serde(default)] // Default to false if missing
297    blocked: bool,
298}
299
300#[derive(Debug, Deserialize, Serialize)] // Add Serialize for potential use in SafetySetting
301#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
302#[expect(clippy::enum_variant_names)]
303enum GeminiHarmCategory {
304    HarmCategoryUnspecified,
305    HarmCategoryDerogatory,
306    HarmCategoryToxicity,
307    HarmCategoryViolence,
308    HarmCategorySexual,
309    HarmCategoryMedical,
310    HarmCategoryDangerous,
311    HarmCategoryHarassment,
312    HarmCategoryHateSpeech,
313    HarmCategorySexuallyExplicit,
314    HarmCategoryDangerousContent,
315    HarmCategoryCivicIntegrity,
316}
317
318#[derive(Debug, Deserialize)]
319#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
320enum GeminiHarmProbability {
321    HarmProbabilityUnspecified,
322    Negligible,
323    Low,
324    Medium,
325    High,
326}
327
328#[expect(dead_code)]
329#[derive(Debug, Deserialize)]
330struct GeminiCitationMetadata {
331    #[serde(rename = "citationSources")]
332    #[serde(skip_serializing_if = "Option::is_none")]
333    citation_sources: Option<Vec<GeminiCitationSource>>,
334}
335
336#[expect(dead_code)]
337#[derive(Debug, Deserialize)]
338struct GeminiCitationSource {
339    #[serde(rename = "startIndex")]
340    #[serde(skip_serializing_if = "Option::is_none")]
341    start_index: Option<i32>,
342    #[serde(rename = "endIndex")]
343    #[serde(skip_serializing_if = "Option::is_none")]
344    end_index: Option<i32>,
345    #[serde(skip_serializing_if = "Option::is_none")]
346    uri: Option<String>,
347    #[serde(skip_serializing_if = "Option::is_none")]
348    license: Option<String>,
349}
350
351#[derive(Debug, Deserialize)]
352struct GeminiUsageMetadata {
353    #[serde(rename = "promptTokenCount")]
354    #[serde(skip_serializing_if = "Option::is_none")]
355    prompt: Option<i32>,
356    #[serde(rename = "candidatesTokenCount")]
357    #[serde(skip_serializing_if = "Option::is_none")]
358    candidates: Option<i32>,
359    #[serde(rename = "totalTokenCount")]
360    #[serde(skip_serializing_if = "Option::is_none")]
361    total: Option<i32>,
362}
363
364#[derive(Debug, Serialize, Deserialize)]
365struct GeminiFunctionResponse {
366    name: String,
367    response: GeminiResponseContent,
368}
369
370#[derive(Debug, Serialize, Deserialize)]
371struct GeminiResponseContent {
372    content: Value,
373}
374
375#[derive(Debug, Serialize, Deserialize)]
376struct GeminiExecutableCode {
377    language: String, // e.g., PYTHON
378    code: String,
379}
380
381#[derive(Debug, Deserialize)]
382#[expect(dead_code)]
383struct GeminiContentResponse {
384    role: String,
385    parts: Vec<GeminiResponsePart>,
386}
387
388fn map_gemini_usage(usage: &GeminiUsageMetadata) -> Option<TokenUsage> {
389    let prompt = usage.prompt.map(i64::from);
390    let candidates = usage.candidates.map(i64::from);
391    let total = usage.total.map(i64::from);
392
393    let input_tokens = prompt.or_else(|| match (total, candidates) {
394        (Some(total_tokens), Some(output_tokens)) if total_tokens >= output_tokens => {
395            Some(total_tokens - output_tokens)
396        }
397        _ => None,
398    })?;
399
400    let output_tokens = candidates.or_else(|| match (total, prompt) {
401        (Some(total_tokens), Some(input_tokens)) if total_tokens >= input_tokens => {
402            Some(total_tokens - input_tokens)
403        }
404        _ => None,
405    })?;
406
407    let total_tokens = total.or_else(|| input_tokens.checked_add(output_tokens))?;
408
409    Some(TokenUsage::new(
410        u32::try_from(input_tokens).ok()?,
411        u32::try_from(output_tokens).ok()?,
412        u32::try_from(total_tokens).ok()?,
413    ))
414}
415
416fn user_image_to_request_part(
417    image: &crate::app::conversation::ImageContent,
418) -> Result<GeminiRequestPart, ApiError> {
419    match &image.source {
420        ImageSource::DataUrl { data_url } => {
421            let Some((meta, payload)) = data_url.split_once(',') else {
422                return Err(ApiError::InvalidRequest {
423                    provider: "google".to_string(),
424                    details: "Image data URL is missing ',' separator".to_string(),
425                });
426            };
427
428            let Some(meta_body) = meta.strip_prefix("data:") else {
429                return Err(ApiError::InvalidRequest {
430                    provider: "google".to_string(),
431                    details: "Image data URL must start with 'data:'".to_string(),
432                });
433            };
434
435            if !meta_body
436                .split(';')
437                .any(|segment| segment.eq_ignore_ascii_case("base64"))
438            {
439                return Err(ApiError::InvalidRequest {
440                    provider: "google".to_string(),
441                    details: "Gemini image data URLs must be base64 encoded".to_string(),
442                });
443            }
444
445            let mime_type = meta_body
446                .split(';')
447                .next()
448                .filter(|value| !value.trim().is_empty())
449                .unwrap_or("application/octet-stream")
450                .to_string();
451
452            if !mime_type.starts_with("image/") {
453                return Err(ApiError::UnsupportedFeature {
454                    provider: "google".to_string(),
455                    feature: "image input mime type".to_string(),
456                    details: format!("Unsupported image media type '{}'", mime_type),
457                });
458            }
459
460            Ok(GeminiRequestPart::InlineData {
461                inline_data: GeminiBlob {
462                    mime_type,
463                    data: payload.to_string(),
464                },
465            })
466        }
467        ImageSource::Url { url } => Ok(GeminiRequestPart::FileData {
468            file_data: GeminiFileData {
469                mime_type: image.mime_type.clone(),
470                file_uri: url.clone(),
471            },
472        }),
473        ImageSource::SessionFile { relative_path } => Err(ApiError::UnsupportedFeature {
474            provider: "google".to_string(),
475            feature: "image input source".to_string(),
476            details: format!(
477                "Gemini adapter cannot access session file '{}' directly; use data URLs or URL sources",
478                relative_path
479            ),
480        }),
481    }
482}
483
484fn convert_messages(messages: Vec<AppMessage>) -> Result<Vec<GeminiContent>, ApiError> {
485    let mut converted = Vec::new();
486
487    for msg in messages {
488        match &msg.data {
489            crate::app::conversation::MessageData::User { content, .. } => {
490                let mut parts = Vec::new();
491                for user_content in content {
492                    match user_content {
493                        UserContent::Text { text } => {
494                            parts.push(GeminiRequestPart::Text { text: text.clone() });
495                        }
496                        UserContent::Image { image } => {
497                            parts.push(user_image_to_request_part(image)?);
498                        }
499                        UserContent::CommandExecution {
500                            command,
501                            stdout,
502                            stderr,
503                            exit_code,
504                        } => {
505                            parts.push(GeminiRequestPart::Text {
506                                text: UserContent::format_command_execution_as_xml(
507                                    command, stdout, stderr, *exit_code,
508                                ),
509                            });
510                        }
511                    }
512                }
513
514                if !parts.is_empty() {
515                    converted.push(GeminiContent {
516                        role: "user".to_string(),
517                        parts,
518                    });
519                }
520            }
521            crate::app::conversation::MessageData::Assistant { content, .. } => {
522                let mut parts = Vec::new();
523                for assistant_content in content {
524                    match assistant_content {
525                        AssistantContent::Text { text } => {
526                            parts.push(GeminiRequestPart::Text { text: text.clone() });
527                        }
528                        AssistantContent::Image { image } => {
529                            match user_image_to_request_part(image) {
530                                Ok(part) => parts.push(part),
531                                Err(err) => {
532                                    debug!(
533                                        target: "gemini::convert_messages",
534                                        "Skipping unsupported assistant image block: {}",
535                                        err
536                                    );
537                                }
538                            }
539                        }
540                        AssistantContent::ToolCall {
541                            tool_call,
542                            thought_signature,
543                        } => parts.push(GeminiRequestPart::FunctionCall {
544                            function_call: GeminiFunctionCall {
545                                name: tool_call.name.clone(),
546                                args: tool_call.parameters.clone(),
547                            },
548                            thought_signature: thought_signature
549                                .as_ref()
550                                .map(|signature| signature.as_str().to_string()),
551                        }),
552                        AssistantContent::Thought { .. } => {
553                            // Gemini doesn't send thought blocks in requests
554                        }
555                    }
556                }
557
558                converted.push(GeminiContent {
559                    role: "model".to_string(),
560                    parts,
561                });
562            }
563            crate::app::conversation::MessageData::Tool {
564                tool_use_id,
565                result,
566                ..
567            } => {
568                let result_value = match result {
569                    ToolResult::Error(e) => Value::String(format!("Error: {e}")),
570                    _ => serde_json::to_value(result)
571                        .unwrap_or_else(|_| Value::String(result.llm_format())),
572                };
573
574                let parts = vec![GeminiRequestPart::FunctionResponse {
575                    function_response: GeminiFunctionResponse {
576                        name: tool_use_id.clone(),
577                        response: GeminiResponseContent {
578                            content: result_value,
579                        },
580                    },
581                }];
582
583                converted.push(GeminiContent {
584                    role: "function".to_string(),
585                    parts,
586                });
587            }
588        }
589    }
590
591    Ok(converted)
592}
593
594fn resolve_ref<'a>(root: &'a Value, schema: &'a Value) -> Option<&'a Value> {
595    let reference = schema.get("$ref").and_then(|v| v.as_str())?;
596    let path = reference.strip_prefix("#/")?;
597    let mut current = root;
598    for segment in path.split('/') {
599        current = current.get(segment)?;
600    }
601    Some(current)
602}
603
604fn infer_type_from_enum(values: &[Value]) -> Option<String> {
605    let mut has_string = false;
606    let mut has_number = false;
607    let mut has_bool = false;
608    let mut has_object = false;
609    let mut has_array = false;
610
611    for value in values {
612        match value {
613            Value::String(_) => has_string = true,
614            Value::Number(_) => has_number = true,
615            Value::Bool(_) => has_bool = true,
616            Value::Object(_) => has_object = true,
617            Value::Array(_) => has_array = true,
618            Value::Null => {}
619        }
620    }
621
622    let kind_count = u8::from(has_string)
623        + u8::from(has_number)
624        + u8::from(has_bool)
625        + u8::from(has_object)
626        + u8::from(has_array);
627
628    if kind_count != 1 {
629        return None;
630    }
631
632    if has_string {
633        Some("string".to_string())
634    } else if has_number {
635        Some("number".to_string())
636    } else if has_bool {
637        Some("boolean".to_string())
638    } else if has_object {
639        Some("object".to_string())
640    } else if has_array {
641        Some("array".to_string())
642    } else {
643        None
644    }
645}
646
647fn normalize_type(value: &Value) -> Value {
648    if let Some(type_str) = value.as_str() {
649        return Value::String(type_str.to_string());
650    }
651
652    if let Some(type_array) = value.as_array()
653        && let Some(primary_type) = type_array
654            .iter()
655            .find_map(|v| if v.is_null() { None } else { v.as_str() })
656    {
657        return Value::String(primary_type.to_string());
658    }
659
660    Value::String("string".to_string())
661}
662
663fn extract_enum_values(value: &Value) -> Vec<Value> {
664    let Some(obj) = value.as_object() else {
665        return Vec::new();
666    };
667
668    if let Some(enum_values) = obj.get("enum").and_then(|v| v.as_array()) {
669        return enum_values
670            .iter()
671            .filter(|v| !v.is_null())
672            .cloned()
673            .collect();
674    }
675
676    if let Some(const_value) = obj.get("const") {
677        if const_value.is_null() {
678            return Vec::new();
679        }
680        return vec![const_value.clone()];
681    }
682
683    Vec::new()
684}
685
686fn merge_property(properties: &mut serde_json::Map<String, Value>, key: &str, value: &Value) {
687    match properties.get_mut(key) {
688        None => {
689            properties.insert(key.to_string(), value.clone());
690        }
691        Some(existing) => {
692            if existing == value {
693                return;
694            }
695
696            let existing_values = extract_enum_values(existing);
697            let incoming_values = extract_enum_values(value);
698            if incoming_values.is_empty() && existing_values.is_empty() {
699                return;
700            }
701
702            let mut combined = existing_values;
703            for item in incoming_values {
704                if !combined.contains(&item) {
705                    combined.push(item);
706                }
707            }
708
709            if combined.is_empty() {
710                return;
711            }
712
713            if let Some(obj) = existing.as_object_mut() {
714                obj.remove("const");
715                obj.insert("enum".to_string(), Value::Array(combined.clone()));
716                if !obj.contains_key("type")
717                    && let Some(inferred) = infer_type_from_enum(&combined)
718                {
719                    obj.insert("type".to_string(), Value::String(inferred));
720                }
721            }
722        }
723    }
724}
725
726fn merge_union_schemas(root: &Value, variants: &[Value]) -> Value {
727    let mut merged_props = serde_json::Map::new();
728    let mut required_intersection: Option<std::collections::BTreeSet<String>> = None;
729    let mut enum_values: Vec<Value> = Vec::new();
730    let mut type_candidates: Vec<String> = Vec::new();
731
732    for variant in variants {
733        let sanitized = sanitize_for_gemini(root, variant);
734
735        if let Some(schema_type) = sanitized.get("type").and_then(|v| v.as_str()) {
736            type_candidates.push(schema_type.to_string());
737        }
738
739        if let Some(props) = sanitized.get("properties").and_then(|v| v.as_object()) {
740            for (key, value) in props {
741                merge_property(&mut merged_props, key, value);
742            }
743        }
744
745        if let Some(req) = sanitized.get("required").and_then(|v| v.as_array()) {
746            let req_set: std::collections::BTreeSet<String> = req
747                .iter()
748                .filter_map(|item| item.as_str().map(|s| s.to_string()))
749                .collect();
750
751            required_intersection = match required_intersection.take() {
752                None => Some(req_set),
753                Some(existing) => Some(
754                    existing
755                        .intersection(&req_set)
756                        .cloned()
757                        .collect::<std::collections::BTreeSet<String>>(),
758                ),
759            };
760        }
761
762        if let Some(values) = sanitized.get("enum").and_then(|v| v.as_array()) {
763            for value in values {
764                if value.is_null() {
765                    continue;
766                }
767                if !enum_values.contains(value) {
768                    enum_values.push(value.clone());
769                }
770            }
771        }
772    }
773
774    let schema_type = if !merged_props.is_empty() {
775        "object".to_string()
776    } else if let Some(inferred) = infer_type_from_enum(&enum_values) {
777        inferred
778    } else if let Some(first) = type_candidates.first() {
779        first.clone()
780    } else {
781        "string".to_string()
782    };
783
784    let mut merged = serde_json::Map::new();
785    merged.insert("type".to_string(), Value::String(schema_type));
786
787    if !merged_props.is_empty() {
788        merged.insert("properties".to_string(), Value::Object(merged_props));
789    }
790
791    if let Some(required_set) = required_intersection
792        && !required_set.is_empty()
793    {
794        merged.insert(
795            "required".to_string(),
796            Value::Array(
797                required_set
798                    .into_iter()
799                    .map(Value::String)
800                    .collect::<Vec<_>>(),
801            ),
802        );
803    }
804
805    if !enum_values.is_empty() {
806        merged.insert("enum".to_string(), Value::Array(enum_values));
807    }
808
809    Value::Object(merged)
810}
811
812fn sanitize_for_gemini(root: &Value, schema: &Value) -> Value {
813    if let Some(resolved) = resolve_ref(root, schema) {
814        return sanitize_for_gemini(root, resolved);
815    }
816
817    let Some(obj) = schema.as_object() else {
818        return schema.clone();
819    };
820
821    if let Some(union) = obj
822        .get("oneOf")
823        .or_else(|| obj.get("anyOf"))
824        .or_else(|| obj.get("allOf"))
825        .and_then(|v| v.as_array())
826    {
827        return merge_union_schemas(root, union);
828    }
829
830    let mut out = serde_json::Map::new();
831    for (key, value) in obj {
832        match key.as_str() {
833            "$ref"
834            | "$defs"
835            | "oneOf"
836            | "anyOf"
837            | "allOf"
838            | "const"
839            | "additionalProperties"
840            | "default"
841            | "examples"
842            | "title"
843            | "pattern"
844            | "minLength"
845            | "maxLength"
846            | "minimum"
847            | "maximum"
848            | "minItems"
849            | "maxItems"
850            | "uniqueItems"
851            | "deprecated" => {}
852            "type" => {
853                out.insert("type".to_string(), normalize_type(value));
854            }
855            "properties" => {
856                if let Some(props) = value.as_object() {
857                    let mut sanitized_props = serde_json::Map::new();
858                    for (prop_key, prop_value) in props {
859                        sanitized_props
860                            .insert(prop_key.clone(), sanitize_for_gemini(root, prop_value));
861                    }
862                    out.insert("properties".to_string(), Value::Object(sanitized_props));
863                }
864            }
865            "items" => {
866                out.insert("items".to_string(), sanitize_for_gemini(root, value));
867            }
868            "enum" => {
869                let values = value
870                    .as_array()
871                    .map(|items| {
872                        items
873                            .iter()
874                            .filter(|v| !v.is_null())
875                            .cloned()
876                            .collect::<Vec<_>>()
877                    })
878                    .unwrap_or_default();
879                out.insert("enum".to_string(), Value::Array(values));
880            }
881            _ => {
882                out.insert(key.clone(), sanitize_for_gemini(root, value));
883            }
884        }
885    }
886
887    if let Some(const_value) = obj.get("const")
888        && !const_value.is_null()
889    {
890        out.insert("enum".to_string(), Value::Array(vec![const_value.clone()]));
891        if !out.contains_key("type")
892            && let Some(inferred) = infer_type_from_enum(std::slice::from_ref(const_value))
893        {
894            out.insert("type".to_string(), Value::String(inferred));
895        }
896    }
897
898    if out.get("type") == Some(&Value::String("object".to_string()))
899        && !out.contains_key("properties")
900    {
901        out.insert(
902            "properties".to_string(),
903            Value::Object(serde_json::Map::new()),
904        );
905    }
906
907    if !out.contains_key("type") {
908        if out.contains_key("properties") {
909            out.insert("type".to_string(), Value::String("object".to_string()));
910        } else if let Some(enum_values) = out.get("enum").and_then(|v| v.as_array())
911            && let Some(inferred) = infer_type_from_enum(enum_values)
912        {
913            out.insert("type".to_string(), Value::String(inferred));
914        }
915    }
916
917    Value::Object(out)
918}
919
920fn simplify_property_schema(key: &str, tool_name: &str, property_value: &Value) -> Value {
921    if let Some(prop_map_orig) = property_value.as_object() {
922        let mut simplified_prop = prop_map_orig.clone();
923
924        // Remove 'additionalProperties' as Gemini doesn't support it
925        if simplified_prop.remove("additionalProperties").is_some() {
926            debug!(target: "gemini::simplify_property_schema", "Removed 'additionalProperties' from property '{}' in tool '{}'", key, tool_name);
927        }
928
929        // Simplify 'type' field (handle arrays like ["string", "null"])
930        if let Some(type_val) = simplified_prop.get_mut("type") {
931            if let Some(type_array) = type_val.as_array() {
932                if let Some(primary_type) = type_array
933                    .iter()
934                    .find_map(|v| if v.is_null() { None } else { v.as_str() })
935                {
936                    *type_val = serde_json::Value::String(primary_type.to_string());
937                } else {
938                    warn!(target: "gemini::simplify_property_schema", "Could not determine primary type for property '{}' in tool '{}', defaulting to string.", key, tool_name);
939                    *type_val = serde_json::Value::String("string".to_string());
940                }
941            } else if !type_val.is_string() {
942                warn!(target: "gemini::simplify_property_schema", "Unexpected 'type' format for property '{}' in tool '{}': {:?}. Defaulting to string.", key, tool_name, type_val);
943                *type_val = serde_json::Value::String("string".to_string());
944            }
945            // If it's already a simple string, do nothing.
946        }
947
948        // Fix integer format if necessary
949        if simplified_prop.get("type") == Some(&serde_json::Value::String("integer".to_string()))
950            && let Some(format_val) = simplified_prop.get_mut("format")
951            && format_val.as_str() == Some("uint64")
952        {
953            *format_val = serde_json::Value::String("int64".to_string());
954            // Optionally remove minimum if Gemini doesn't support it with int64
955            // simplified_prop.remove("minimum");
956        }
957
958        // For string types, Gemini only supports 'enum' and 'date-time' formats
959        if simplified_prop.get("type") == Some(&serde_json::Value::String("string".to_string())) {
960            let should_remove_format = simplified_prop
961                .get("format")
962                .and_then(|f| f.as_str())
963                .is_some_and(|format_str| format_str != "enum" && format_str != "date-time");
964
965            if should_remove_format
966                && let Some(format_val) = simplified_prop.remove("format")
967                && let Some(format_str) = format_val.as_str()
968            {
969                debug!(target: "gemini::simplify_property_schema", "Removed unsupported format '{}' from string property '{}' in tool '{}'", format_str, key, tool_name);
970            }
971
972            // Also remove other string validation fields that might not be supported
973            if simplified_prop.remove("minLength").is_some() {
974                debug!(target: "gemini::simplify_property_schema", "Removed 'minLength' from string property '{}' in tool '{}'", key, tool_name);
975            }
976            if simplified_prop.remove("maxLength").is_some() {
977                debug!(target: "gemini::simplify_property_schema", "Removed 'maxLength' from string property '{}' in tool '{}'", key, tool_name);
978            }
979            if simplified_prop.remove("pattern").is_some() {
980                debug!(target: "gemini::simplify_property_schema", "Removed 'pattern' from string property '{}' in tool '{}'", key, tool_name);
981            }
982        }
983
984        // Recursively simplify 'items' if this is an array type
985        if simplified_prop.get("type") == Some(&serde_json::Value::String("array".to_string()))
986            && let Some(items_val) = simplified_prop.get_mut("items")
987        {
988            *items_val = simplify_property_schema(&format!("{key}.items"), tool_name, items_val);
989        }
990
991        // Recursively simplify nested 'properties' if this is an object type
992        if simplified_prop.get("type") == Some(&serde_json::Value::String("object".to_string()))
993            && let Some(Value::Object(props)) = simplified_prop.get_mut("properties")
994        {
995            let simplified_nested_props: serde_json::Map<String, Value> = props
996                .iter()
997                .map(|(nested_key, nested_value)| {
998                    (
999                        nested_key.clone(),
1000                        simplify_property_schema(
1001                            &format!("{key}.{nested_key}"),
1002                            tool_name,
1003                            nested_value,
1004                        ),
1005                    )
1006                })
1007                .collect();
1008            *props = simplified_nested_props;
1009        }
1010
1011        serde_json::Value::Object(simplified_prop)
1012    } else {
1013        warn!(target: "gemini::simplify_property_schema", "Property value for '{}' in tool '{}' is not an object: {:?}. Using original value.", key, tool_name, property_value);
1014        property_value.clone() // Return original if not an object
1015    }
1016}
1017
1018fn convert_tools(tools: Vec<ToolSchema>) -> Vec<GeminiTool> {
1019    let function_declarations = tools
1020        .into_iter()
1021        .map(|tool| {
1022            let root_schema = tool.input_schema.as_value();
1023            let summary = tool.input_schema.summary();
1024            let schema_type = if summary.schema_type.is_empty() {
1025                "object".to_string()
1026            } else {
1027                summary.schema_type
1028            };
1029
1030            // Simplify properties schema for Gemini using the helper function
1031            let simplified_properties = summary
1032                .properties
1033                .iter()
1034                .map(|(key, value)| {
1035                    let sanitized = sanitize_for_gemini(root_schema, value);
1036                    (
1037                        key.clone(),
1038                        simplify_property_schema(key, &tool.name, &sanitized),
1039                    )
1040                })
1041                .collect();
1042
1043            // Construct the parameters object using the specific struct
1044            let parameters = GeminiParameterSchema {
1045                schema_type,                       // Use schema_type field (usually "object")
1046                properties: simplified_properties, // Use simplified properties
1047                required: summary.required,        // Use required field
1048            };
1049
1050            GeminiFunctionDeclaration {
1051                name: tool.name,
1052                description: tool.description,
1053                parameters,
1054            }
1055        })
1056        .collect();
1057
1058    vec![GeminiTool {
1059        function_declarations,
1060    }]
1061}
1062
1063fn convert_response(response: GeminiResponse) -> Result<CompletionResponse, ApiError> {
1064    // Log prompt feedback if present
1065    if let Some(feedback) = &response.prompt_feedback
1066        && let Some(reason) = &feedback.block_reason
1067    {
1068        let details = format!(
1069            "Prompt blocked due to {:?}. Safety ratings: {:?}",
1070            reason, feedback.safety_ratings
1071        );
1072        warn!(target: "gemini::convert_response", "{}", details);
1073        // Return the specific RequestBlocked error
1074        return Err(ApiError::RequestBlocked {
1075            provider: "google".to_string(), // Assuming "google" is the provider name
1076            details,
1077        });
1078    }
1079
1080    // Check candidates *after* checking for prompt blocking
1081    let candidates = if let Some(cands) = response.candidates {
1082        if cands.is_empty() {
1083            // If it was blocked, the previous check should have caught it.
1084            // So, this means no candidates were generated for other reasons.
1085            warn!(target: "gemini::convert_response", "No candidates received, and prompt was not blocked.");
1086            // Use NoChoices error here
1087            return Err(ApiError::NoChoices {
1088                provider: "google".to_string(),
1089            });
1090        }
1091        cands // Return the non-empty vector
1092    } else {
1093        warn!(target: "gemini::convert_response", "No candidates field in Gemini response.");
1094        // Use NoChoices error here as well
1095        return Err(ApiError::NoChoices {
1096            provider: "google".to_string(),
1097        });
1098    };
1099
1100    // For simplicity, still taking the first candidate. Multi-candidate handling could be added.
1101    // Access candidates safely since we've checked it's not None or empty.
1102    let candidate = &candidates[0];
1103
1104    // Log finish reason and safety ratings if present
1105    if let Some(reason) = &candidate.finish_reason {
1106        match reason {
1107            GeminiFinishReason::Stop => { /* Normal completion */ }
1108            GeminiFinishReason::MaxTokens => {
1109                warn!(target: "gemini::convert_response", "Response stopped due to MaxTokens limit.");
1110            }
1111            GeminiFinishReason::Safety => {
1112                warn!(target: "gemini::convert_response", "Response stopped due to safety settings. Ratings: {:?}", candidate.safety_ratings);
1113                // Consider returning an error or modifying the response based on safety ratings
1114            }
1115            GeminiFinishReason::Recitation => {
1116                warn!(target: "gemini::convert_response", "Response stopped due to potential recitation. Citations: {:?}", candidate.citation_metadata);
1117            }
1118            GeminiFinishReason::MalformedFunctionCall => {
1119                warn!(target: "gemini::convert_response", "Response stopped due to malformed function call.");
1120            }
1121            _ => {
1122                info!(target: "gemini::convert_response", "Response finished with reason: {:?}", reason);
1123            }
1124        }
1125    }
1126
1127    // Log usage metadata if present
1128    if let Some(usage) = &response.usage_metadata {
1129        debug!(target: "gemini::convert_response", "Usage - Prompt Tokens: {:?}, Candidates Tokens: {:?}, Total Tokens: {:?}",
1130               usage.prompt, usage.candidates, usage.total);
1131    }
1132    let usage = response.usage_metadata.as_ref().and_then(map_gemini_usage);
1133
1134    let content: Vec<AssistantContent> = candidate
1135        .content // GeminiContentResponse
1136        .parts   // Vec<GeminiResponsePart> (struct)
1137        .iter()
1138        .filter_map(|part| { // part is &GeminiResponsePart (struct)
1139            // Check if this is a thought part first
1140            if part.thought {
1141                debug!(target: "gemini::convert_response", "Received thought part: {:?}", part);
1142                // For thought parts, extract text content and create a Thought block
1143                if let GeminiResponsePartData::Text { text } = &part.data {
1144                    Some(AssistantContent::Thought {
1145                        thought: ThoughtContent::Simple {
1146                            text: text.clone(),
1147                        },
1148                    })
1149                } else {
1150                    warn!(target: "gemini::convert_response", "Thought part contains non-text data: {:?}", part.data);
1151                    None
1152                }
1153            } else {
1154                // Regular (non-thought) content processing
1155                match &part.data {
1156                    GeminiResponsePartData::Text { text } => Some(AssistantContent::Text {
1157                        text: text.clone(),
1158                    }),
1159                    GeminiResponsePartData::InlineData { inline_data } => {
1160                        warn!(target: "gemini::convert_response", "Received InlineData part (MIME type: {}). Converting to placeholder text.", inline_data.mime_type);
1161                        Some(AssistantContent::Text { text: format!("[Inline Data: {}]", inline_data.mime_type) })
1162                    }
1163                    GeminiResponsePartData::FunctionCall { function_call } => {
1164                        Some(AssistantContent::ToolCall {
1165                            tool_call: steer_tools::ToolCall {
1166                                id: uuid::Uuid::new_v4().to_string(), // Generate a synthetic ID
1167                                name: function_call.name.clone(),
1168                                parameters: function_call.args.clone(),
1169                            },
1170                            thought_signature: part
1171                                .thought_signature
1172                                .clone()
1173                                .map(ThoughtSignature::new),
1174                        })
1175                    }
1176                    GeminiResponsePartData::FileData { file_data } => {
1177                        warn!(target: "gemini::convert_response", "Received FileData part (URI: {}). Converting to placeholder text.", file_data.file_uri);
1178                        Some(AssistantContent::Text { text: format!("[File Data: {}]", file_data.file_uri) })
1179                    }
1180                     GeminiResponsePartData::ExecutableCode { executable_code } => {
1181                         info!(target: "gemini::convert_response", "Received ExecutableCode part ({}). Converting to text.",
1182                              executable_code.language);
1183                         Some(AssistantContent::Text {
1184                             text: format!(
1185                                 "```{}
1186{}
1187```",
1188                                 executable_code.language.to_lowercase(),
1189                                 executable_code.code
1190                             ),
1191                         })
1192                     }
1193                }
1194            }
1195        })
1196        .collect();
1197
1198    Ok(CompletionResponse { content, usage })
1199}
1200
1201#[async_trait]
1202impl Provider for GeminiClient {
1203    fn name(&self) -> &'static str {
1204        "google"
1205    }
1206
1207    async fn complete(
1208        &self,
1209        model_id: &ModelId,
1210        messages: Vec<AppMessage>,
1211        system: Option<SystemContext>,
1212        tools: Option<Vec<ToolSchema>>,
1213        _call_options: Option<ModelParameters>,
1214        token: CancellationToken,
1215    ) -> Result<CompletionResponse, ApiError> {
1216        let model_name = &model_id.id; // Use the model ID string
1217        let url = format!(
1218            "{}/models/{}:generateContent?key={}",
1219            GEMINI_API_BASE, model_name, self.api_key
1220        );
1221
1222        let gemini_contents = convert_messages(messages)?;
1223
1224        let system_instruction = system
1225            .and_then(|context| context.render())
1226            .map(|instructions| GeminiSystemInstruction {
1227                parts: vec![GeminiRequestPart::Text { text: instructions }],
1228            });
1229
1230        let gemini_tools = tools.map(convert_tools);
1231
1232        // Derive generation config from call options, respecting model/catalog settings
1233        let (temperature, top_p, max_output_tokens) = {
1234            let opts = _call_options.as_ref();
1235            (
1236                opts.and_then(|o| o.temperature).or(Some(1.0)),
1237                opts.and_then(|o| o.top_p).or(Some(0.95)),
1238                opts.and_then(|o| o.max_tokens)
1239                    .map(|v| v as i32)
1240                    .or(Some(65536)),
1241            )
1242        };
1243        let thinking_config = _call_options
1244            .as_ref()
1245            .and_then(|o| o.thinking_config)
1246            .and_then(|tc| {
1247                if tc.enabled {
1248                    Some(GeminiThinkingConfig {
1249                        include_thoughts: tc.include_thoughts,
1250                        thinking_budget: tc.budget_tokens.map(|v| v as i32),
1251                    })
1252                } else {
1253                    None
1254                }
1255            });
1256
1257        let request = GeminiRequest {
1258            contents: gemini_contents,
1259            system_instruction,
1260            tools: gemini_tools,
1261            generation_config: Some(GeminiGenerationConfig {
1262                max_output_tokens,
1263                temperature,
1264                top_p,
1265                thinking_config,
1266                ..Default::default()
1267            }),
1268        };
1269
1270        let response = tokio::select! {
1271            biased;
1272            () = token.cancelled() => {
1273                debug!(target: "gemini::complete", "Cancellation token triggered before sending request.");
1274                return Err(ApiError::Cancelled{ provider: self.name().to_string()});
1275            }
1276            res = self.client.post(&url).json(&request).send() => {
1277                res.map_err(ApiError::Network)?
1278            }
1279        };
1280        let status = response.status();
1281
1282        if status != StatusCode::OK {
1283            let error_text = response.text().await.map_err(ApiError::Network)?;
1284            error!(target: "Gemini API Error Response", "Status: {}, Body: {}", status, error_text);
1285            return Err(match status.as_u16() {
1286                401 | 403 => ApiError::AuthenticationFailed {
1287                    provider: self.name().to_string(),
1288                    details: error_text,
1289                },
1290                429 => ApiError::RateLimited {
1291                    provider: self.name().to_string(),
1292                    details: error_text,
1293                },
1294                400 | 404 => {
1295                    error!(target: "Gemini API Error Response", "Status: {}, Body: {}, Request: {}", status, error_text, serde_json::to_string_pretty(&request).unwrap_or_else(|_| "Failed to serialize request".to_string()));
1296                    ApiError::InvalidRequest {
1297                        provider: self.name().to_string(),
1298                        details: error_text,
1299                    }
1300                } // 404 might mean invalid model
1301                500..=599 => ApiError::ServerError {
1302                    provider: self.name().to_string(),
1303                    status_code: status.as_u16(),
1304                    details: error_text,
1305                },
1306                _ => ApiError::Unknown {
1307                    provider: self.name().to_string(),
1308                    details: error_text,
1309                },
1310            });
1311        }
1312
1313        let response_text = response.text().await.map_err(ApiError::Network)?;
1314
1315        match serde_json::from_str::<GeminiResponse>(&response_text) {
1316            Ok(gemini_response) => {
1317                convert_response(gemini_response).map_err(|e| ApiError::ResponseParsingError {
1318                    provider: self.name().to_string(),
1319                    details: e.to_string(),
1320                })
1321            }
1322            Err(e) => {
1323                error!(target: "Gemini API JSON Parsing Error", "Failed to parse JSON: {}. Response body:\n{}", e, response_text);
1324                Err(ApiError::ResponseParsingError {
1325                    provider: self.name().to_string(),
1326                    details: format!("Status: {status}, Error: {e}, Body: {response_text}"),
1327                })
1328            }
1329        }
1330    }
1331
1332    async fn stream_complete(
1333        &self,
1334        model_id: &ModelId,
1335        messages: Vec<AppMessage>,
1336        system: Option<SystemContext>,
1337        tools: Option<Vec<ToolSchema>>,
1338        _call_options: Option<ModelParameters>,
1339        token: CancellationToken,
1340    ) -> Result<CompletionStream, ApiError> {
1341        let model_name = &model_id.id;
1342        let url = format!(
1343            "{}/models/{}:streamGenerateContent?alt=sse&key={}",
1344            GEMINI_API_BASE, model_name, self.api_key
1345        );
1346
1347        let gemini_contents = convert_messages(messages)?;
1348
1349        let system_instruction = system
1350            .and_then(|context| context.render())
1351            .map(|instructions| GeminiSystemInstruction {
1352                parts: vec![GeminiRequestPart::Text { text: instructions }],
1353            });
1354
1355        let gemini_tools = tools.map(convert_tools);
1356
1357        let (temperature, top_p, max_output_tokens) = {
1358            let opts = _call_options.as_ref();
1359            (
1360                opts.and_then(|o| o.temperature).or(Some(1.0)),
1361                opts.and_then(|o| o.top_p).or(Some(0.95)),
1362                opts.and_then(|o| o.max_tokens)
1363                    .map(|v| v as i32)
1364                    .or(Some(65536)),
1365            )
1366        };
1367        let thinking_config = _call_options
1368            .as_ref()
1369            .and_then(|o| o.thinking_config)
1370            .and_then(|tc| {
1371                if tc.enabled {
1372                    Some(GeminiThinkingConfig {
1373                        include_thoughts: tc.include_thoughts,
1374                        thinking_budget: tc.budget_tokens.map(|v| v as i32),
1375                    })
1376                } else {
1377                    None
1378                }
1379            });
1380
1381        let request = GeminiRequest {
1382            contents: gemini_contents,
1383            system_instruction,
1384            tools: gemini_tools,
1385            generation_config: Some(GeminiGenerationConfig {
1386                max_output_tokens,
1387                temperature,
1388                top_p,
1389                thinking_config,
1390                ..Default::default()
1391            }),
1392        };
1393
1394        let response = tokio::select! {
1395            biased;
1396            () = token.cancelled() => {
1397                return Err(ApiError::Cancelled{ provider: self.name().to_string()});
1398            }
1399            res = self.client.post(&url).json(&request).send() => {
1400                res.map_err(ApiError::Network)?
1401            }
1402        };
1403
1404        let status = response.status();
1405        if status != StatusCode::OK {
1406            let error_text = response.text().await.map_err(ApiError::Network)?;
1407            error!(target: "gemini::stream", "API error - Status: {}, Body: {}", status, error_text);
1408            return Err(match status.as_u16() {
1409                401 | 403 => ApiError::AuthenticationFailed {
1410                    provider: self.name().to_string(),
1411                    details: error_text,
1412                },
1413                429 => ApiError::RateLimited {
1414                    provider: self.name().to_string(),
1415                    details: error_text,
1416                },
1417                400 | 404 => ApiError::InvalidRequest {
1418                    provider: self.name().to_string(),
1419                    details: error_text,
1420                },
1421                500..=599 => ApiError::ServerError {
1422                    provider: self.name().to_string(),
1423                    status_code: status.as_u16(),
1424                    details: error_text,
1425                },
1426                _ => ApiError::Unknown {
1427                    provider: self.name().to_string(),
1428                    details: error_text,
1429                },
1430            });
1431        }
1432
1433        let byte_stream = response.bytes_stream();
1434        let sse_stream = parse_sse_stream(byte_stream);
1435
1436        Ok(Box::pin(Self::convert_gemini_stream(sse_stream, token)))
1437    }
1438}
1439
1440impl GeminiClient {
1441    fn convert_gemini_stream(
1442        mut sse_stream: impl futures::Stream<Item = Result<crate::api::sse::SseEvent, SseParseError>>
1443        + Unpin
1444        + Send
1445        + 'static,
1446        token: CancellationToken,
1447    ) -> impl futures::Stream<Item = StreamChunk> + Send + 'static {
1448        async_stream::stream! {
1449            let mut content: Vec<AssistantContent> = Vec::new();
1450            let mut latest_usage: Option<TokenUsage> = None;
1451            loop {
1452                if token.is_cancelled() {
1453                    yield StreamChunk::Error(StreamError::Cancelled);
1454                    break;
1455                }
1456
1457                let event_result = tokio::select! {
1458                    biased;
1459                    () = token.cancelled() => {
1460                        yield StreamChunk::Error(StreamError::Cancelled);
1461                        break;
1462                    }
1463                    event = sse_stream.next() => event
1464                };
1465
1466                let Some(event_result) = event_result else {
1467                    let content = std::mem::take(&mut content);
1468                    yield StreamChunk::MessageComplete(CompletionResponse {
1469                        content,
1470                        usage: latest_usage,
1471                    });
1472                    break;
1473                };
1474
1475                let event = match event_result {
1476                    Ok(e) => e,
1477                    Err(e) => {
1478                        yield StreamChunk::Error(StreamError::SseParse(e));
1479                        break;
1480                    }
1481                };
1482
1483                let chunk: GeminiResponse = match serde_json::from_str(&event.data) {
1484                    Ok(c) => c,
1485                    Err(e) => {
1486                        debug!(target: "gemini::stream", "Failed to parse chunk: {} data: {}", e, event.data);
1487                        continue;
1488                    }
1489                };
1490
1491                if let Some(usage) = chunk.usage_metadata.as_ref()
1492                    && let Some(mapped) = map_gemini_usage(usage) {
1493                        latest_usage = Some(mapped);
1494                    }
1495
1496                if let Some(candidates) = chunk.candidates {
1497                    for candidate in candidates {
1498                        for part in candidate.content.parts {
1499                            let GeminiResponsePart {
1500                                thought,
1501                                thought_signature,
1502                                data,
1503                            } = part;
1504                            let thought_signature =
1505                                thought_signature.map(ThoughtSignature::new);
1506
1507                            if thought {
1508                                if let GeminiResponsePartData::Text { text } = data {
1509                                    match content.last_mut() {
1510                                        Some(AssistantContent::Thought {
1511                                            thought: ThoughtContent::Simple { text: buf },
1512                                        }) => buf.push_str(&text),
1513                                        _ => {
1514                                            content.push(AssistantContent::Thought {
1515                                                thought: ThoughtContent::Simple { text: text.clone() },
1516                                            });
1517                                        }
1518                                    }
1519                                    yield StreamChunk::ThinkingDelta(text);
1520                                }
1521                            } else {
1522                                match data {
1523                                    GeminiResponsePartData::Text { text } => {
1524                                        match content.last_mut() {
1525                                            Some(AssistantContent::Text { text: buf }) => buf.push_str(&text),
1526                                            _ => {
1527                                                content.push(AssistantContent::Text { text: text.clone() });
1528                                            }
1529                                        }
1530                                        yield StreamChunk::TextDelta(text);
1531                                    }
1532                                    GeminiResponsePartData::FunctionCall { function_call } => {
1533                                        let id = uuid::Uuid::new_v4().to_string();
1534                                        content.push(AssistantContent::ToolCall {
1535                                            tool_call: steer_tools::ToolCall {
1536                                                id: id.clone(),
1537                                                name: function_call.name.clone(),
1538                                                parameters: function_call.args.clone(),
1539                                            },
1540                                            thought_signature,
1541                                        });
1542                                        yield StreamChunk::ToolUseStart {
1543                                            id: id.clone(),
1544                                            name: function_call.name,
1545                                        };
1546                                        yield StreamChunk::ToolUseInputDelta {
1547                                            id,
1548                                            delta: function_call.args.to_string(),
1549                                        };
1550                                    }
1551                                    _ => {}
1552                                }
1553                            }
1554                        }
1555                    }
1556                }
1557            }
1558        }
1559    }
1560}
1561
#[cfg(test)]
mod tests {
    // Unit tests for the Gemini client: JSON-schema simplification, tool and
    // message conversion, usage mapping, response conversion and SSE stream
    // conversion. The last test talks to the real API and is `#[ignore]`d.
    //
    // `StreamChunk`, `StreamError`, `CancellationToken`, `StreamExt`,
    // `Provider`, `TokenUsage` and `AssistantContent` all come in through
    // `use super::*`; only names the parent module does not import are listed
    // explicitly below.
    use super::*;
    use crate::api::sse::SseEvent;
    use crate::app::conversation::{ImageContent, ImageSource, Message, MessageData, UserContent};
    use futures::stream;
    use serde_json::json;
    use std::pin::pin;

    /// Wraps a raw JSON payload in an [`SseEvent`] with no event type and no id,
    /// which is the shape every stream-conversion test below feeds in.
    fn sse_data(data: &str) -> SseEvent {
        SseEvent {
            event_type: None,
            data: data.to_string(),
            id: None,
        }
    }

    /// Builds a single user message (id `msg-image`, timestamp 1) whose only
    /// content is an `image/png` image loaded from `source`.
    fn user_image_message(source: ImageSource) -> Message {
        Message {
            data: MessageData::User {
                content: vec![UserContent::Image {
                    image: ImageContent {
                        mime_type: "image/png".to_string(),
                        source,
                        width: None,
                        height: None,
                        bytes: None,
                        sha256: None,
                    },
                }],
            },
            timestamp: 1,
            id: "msg-image".to_string(),
            parent_message_id: None,
        }
    }

    #[test]
    fn test_simplify_property_schema_removes_additional_properties() {
        let property_value = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            },
            "additionalProperties": false
        });

        let expected = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });

        let result = simplify_property_schema("testProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_removes_unsupported_string_formats() {
        // `uri` is not kept as a format, and the length/pattern constraints
        // are stripped as well.
        let property_value = json!({
            "type": "string",
            "format": "uri",
            "minLength": 1,
            "maxLength": 100,
            "pattern": "^https://"
        });

        let expected = json!({
            "type": "string"
        });

        let result = simplify_property_schema("urlProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_keeps_supported_string_formats() {
        let property_value = json!({
            "type": "string",
            "format": "date-time"
        });

        let expected = json!({
            "type": "string",
            "format": "date-time"
        });

        let result = simplify_property_schema("dateProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_handles_array_types() {
        // A nullable union type collapses to its non-null member.
        let property_value = json!({
            "type": ["string", "null"],
            "format": "email"
        });

        let expected = json!({
            "type": "string"
        });

        let result = simplify_property_schema("emailProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_recursively_handles_array_items() {
        let property_value = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "url": {
                        "type": "string",
                        "format": "uri"
                    }
                },
                "additionalProperties": false
            }
        });

        let expected = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "url": {
                        "type": "string"
                    }
                }
            }
        });

        let result = simplify_property_schema("linksProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_recursively_handles_nested_objects() {
        let property_value = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "type": "object",
                    "properties": {
                        "field": {
                            "type": "string",
                            "format": "hostname"
                        }
                    },
                    "additionalProperties": true
                }
            },
            "additionalProperties": false
        });

        let expected = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "type": "object",
                    "properties": {
                        "field": {
                            "type": "string"
                        }
                    }
                }
            }
        });

        let result = simplify_property_schema("complexProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_fixes_uint64_format() {
        // Unsigned 64-bit is rewritten to the signed `int64` format.
        let property_value = json!({
            "type": "integer",
            "format": "uint64"
        });

        let expected = json!({
            "type": "integer",
            "format": "int64"
        });

        let result = simplify_property_schema("idProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_convert_tools_integration() {
        use steer_tools::{InputSchema, ToolSchema};

        let tool = ToolSchema {
            name: "create_issue".to_string(),
            display_name: "Create Issue".to_string(),
            description: "Create an issue".to_string(),
            input_schema: InputSchema::object(
                {
                    let mut props = serde_json::Map::new();
                    props.insert(
                        "title".to_string(),
                        json!({
                            "type": "string",
                            "minLength": 1
                        }),
                    );
                    props.insert(
                        "links".to_string(),
                        json!({
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "url": {
                                        "type": "string",
                                        "format": "uri"
                                    }
                                },
                                "additionalProperties": false
                            }
                        }),
                    );
                    props
                },
                vec!["title".to_string()],
            ),
        };

        // `display_name` is not part of the Gemini declaration, and every
        // property schema is simplified the same way the unit tests above show.
        let expected_tools = vec![GeminiTool {
            function_declarations: vec![GeminiFunctionDeclaration {
                name: "create_issue".to_string(),
                description: "Create an issue".to_string(),
                parameters: GeminiParameterSchema {
                    schema_type: "object".to_string(),
                    properties: {
                        let mut props = serde_json::Map::new();
                        props.insert(
                            "title".to_string(),
                            json!({
                                "type": "string"
                            }),
                        );
                        props.insert(
                            "links".to_string(),
                            json!({
                                "type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "url": {
                                            "type": "string"
                                        }
                                    }
                                }
                            }),
                        );
                        props
                    },
                    required: vec!["title".to_string()],
                },
            }],
        }];

        let result = convert_tools(vec![tool]);
        assert_eq!(result, expected_tools);
    }

    #[test]
    fn test_convert_messages_includes_data_url_image_as_inline_data() {
        // `Zm9v` is base64 for "foo". The payload after the comma of the data
        // URL is expected to be forwarded verbatim as `inlineData.data`, so
        // the URL must actually carry it (an empty URL has no payload at all).
        let messages = vec![user_image_message(ImageSource::DataUrl {
            data_url: "data:image/png;base64,Zm9v".to_string(),
        })];

        let converted = convert_messages(messages).expect("convert messages");
        assert_eq!(converted.len(), 1);
        let first = &converted[0];
        assert_eq!(first.role, "user");
        assert_eq!(first.parts.len(), 1);

        let json = serde_json::to_value(&first.parts[0]).expect("serialize part");
        assert_eq!(json["inlineData"]["mimeType"], "image/png");
        assert_eq!(json["inlineData"]["data"], "Zm9v");
    }

    #[test]
    fn test_convert_messages_rejects_session_file_image_source() {
        let messages = vec![user_image_message(ImageSource::SessionFile {
            relative_path: "session-1/image.png".to_string(),
        })];

        let err = convert_messages(messages).expect_err("expected unsupported feature");
        match err {
            ApiError::UnsupportedFeature {
                provider,
                feature,
                details,
            } => {
                assert_eq!(provider, "google");
                assert_eq!(feature, "image input source");
                assert!(details.contains("session file"));
            }
            other => panic!("Expected UnsupportedFeature, got {other:?}"),
        }
    }

    #[test]
    fn test_map_gemini_usage_derives_missing_counter() {
        // The prompt count is missing and is expected to be derived as
        // total - candidates = 19 - 7 = 12.
        let usage = GeminiUsageMetadata {
            prompt: None,
            candidates: Some(7),
            total: Some(19),
        };

        assert_eq!(map_gemini_usage(&usage), Some(TokenUsage::new(12, 7, 19)));
    }

    #[test]
    fn test_map_gemini_usage_returns_none_when_counters_not_derivable() {
        // With only the prompt count, neither of the other two can be derived.
        let usage = GeminiUsageMetadata {
            prompt: Some(9),
            candidates: None,
            total: None,
        };

        assert_eq!(map_gemini_usage(&usage), None);
    }

    #[test]
    fn test_convert_response_maps_usage_metadata() {
        let response = GeminiResponse {
            candidates: Some(vec![GeminiCandidate {
                content: GeminiContentResponse {
                    role: "model".to_string(),
                    parts: vec![GeminiResponsePart {
                        thought: false,
                        thought_signature: None,
                        data: GeminiResponsePartData::Text {
                            text: "Hello".to_string(),
                        },
                    }],
                },
                finish_reason: Some(GeminiFinishReason::Stop),
                safety_ratings: None,
                citation_metadata: None,
            }]),
            prompt_feedback: None,
            usage_metadata: Some(GeminiUsageMetadata {
                prompt: Some(11),
                candidates: Some(5),
                total: Some(16),
            }),
        };

        let converted = convert_response(response).expect("response should convert");

        assert_eq!(converted.usage, Some(TokenUsage::new(11, 5, 16)));
        assert!(matches!(
            converted.content.first(),
            Some(AssistantContent::Text { text }) if text == "Hello"
        ));
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_text_deltas() {
        let events = vec![
            Ok(sse_data(
                r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#,
            )),
            Ok(sse_data(
                r#"{"candidates":[{"content":{"role":"model","parts":[{"text":" world"}]}}]}"#,
            )),
        ];

        let sse_stream = stream::iter(events);
        let token = CancellationToken::new();
        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));

        let first_delta = stream.next().await.unwrap();
        assert!(matches!(first_delta, StreamChunk::TextDelta(ref t) if t == "Hello"));

        let second_delta = stream.next().await.unwrap();
        assert!(matches!(second_delta, StreamChunk::TextDelta(ref t) if t == " world"));

        let complete = stream.next().await.unwrap();
        assert!(matches!(complete, StreamChunk::MessageComplete(_)));
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_captures_final_usage() {
        // Usage arrives in a trailing event that carries no candidates. It must
        // still be attached to the final MessageComplete.
        let events = vec![
            Ok(sse_data(
                r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#,
            )),
            Ok(sse_data(
                r#"{"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":4,"totalTokenCount":14}}"#,
            )),
        ];

        let sse_stream = stream::iter(events);
        let token = CancellationToken::new();
        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));

        let first_delta = stream.next().await.unwrap();
        assert!(matches!(first_delta, StreamChunk::TextDelta(ref t) if t == "Hello"));

        let complete = stream.next().await.unwrap();
        let StreamChunk::MessageComplete(response) = complete else {
            panic!("Expected MessageComplete");
        };
        assert_eq!(response.usage, Some(TokenUsage::new(10, 4, 14)));
        assert!(matches!(
            response.content.first(),
            Some(AssistantContent::Text { text }) if text == "Hello"
        ));
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_with_thinking() {
        let events = vec![
            Ok(sse_data(
                r#"{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"Let me think..."}]}}]}"#,
            )),
            Ok(sse_data(
                r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"The answer"}]}}]}"#,
            )),
        ];

        let sse_stream = stream::iter(events);
        let token = CancellationToken::new();
        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));

        // A part flagged `"thought": true` is surfaced as thinking, not text.
        let thinking_delta = stream.next().await.unwrap();
        assert!(
            matches!(thinking_delta, StreamChunk::ThinkingDelta(ref t) if t == "Let me think...")
        );

        let text_delta = stream.next().await.unwrap();
        assert!(matches!(text_delta, StreamChunk::TextDelta(ref t) if t == "The answer"));

        let complete = stream.next().await.unwrap();
        let StreamChunk::MessageComplete(response) = complete else {
            panic!("Expected MessageComplete");
        };
        assert_eq!(response.content.len(), 2);
        assert!(matches!(
            &response.content[0],
            AssistantContent::Thought { .. }
        ));
        assert!(matches!(
            &response.content[1],
            AssistantContent::Text { .. }
        ));
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_with_function_call() {
        let events = vec![Ok(sse_data(
            r#"{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"get_weather","args":{"city":"NYC"}},"thoughtSignature":"sig_123"}]}}]}"#,
        ))];

        let sse_stream = stream::iter(events);
        let token = CancellationToken::new();
        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));

        let tool_start = stream.next().await.unwrap();
        assert!(
            matches!(tool_start, StreamChunk::ToolUseStart { ref name, .. } if name == "get_weather")
        );

        let tool_input = stream.next().await.unwrap();
        assert!(matches!(tool_input, StreamChunk::ToolUseInputDelta { .. }));

        let complete = stream.next().await.unwrap();
        let StreamChunk::MessageComplete(response) = complete else {
            panic!("Expected MessageComplete");
        };
        assert_eq!(response.content.len(), 1);

        // The part-level `thoughtSignature` must travel with the tool call.
        let AssistantContent::ToolCall {
            tool_call,
            thought_signature,
        } = &response.content[0]
        else {
            panic!("Expected ToolCall");
        };
        assert_eq!(tool_call.name, "get_weather");
        assert_eq!(
            thought_signature.as_ref().map(|sig| sig.as_str()),
            Some("sig_123")
        );
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_cancellation() {
        let events = vec![Ok(sse_data(
            r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#,
        ))];

        let sse_stream = stream::iter(events);
        let token = CancellationToken::new();
        // Cancel before the first poll: the stream must report cancellation
        // instead of emitting the pending text delta.
        token.cancel();

        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));

        let cancelled = stream.next().await.unwrap();
        assert!(matches!(
            cancelled,
            StreamChunk::Error(StreamError::Cancelled)
        ));
    }

    #[tokio::test]
    #[ignore = "Requires GOOGLE_API_KEY environment variable"]
    async fn test_stream_complete_real_api() {
        dotenvy::dotenv().ok();
        let api_key = std::env::var("GOOGLE_API_KEY").expect("GOOGLE_API_KEY must be set");
        let client = GeminiClient::new(api_key);

        let message = Message {
            data: MessageData::User {
                content: vec![UserContent::Text {
                    text: "Say exactly: Hello".to_string(),
                }],
            },
            timestamp: chrono::Utc::now().timestamp_millis() as u64,
            id: "test-msg".to_string(),
            parent_message_id: None,
        };

        let model_id = ModelId::new(
            crate::config::provider::google(),
            "gemini-2.5-flash-preview-04-17",
        );
        let token = CancellationToken::new();

        let mut stream = client
            .stream_complete(&model_id, vec![message], None, None, None, token)
            .await
            .expect("stream_complete should succeed");

        let mut got_text_delta = false;
        let mut got_message_complete = false;
        let mut accumulated_text = String::new();

        while let Some(chunk) = stream.next().await {
            match chunk {
                StreamChunk::TextDelta(text) => {
                    got_text_delta = true;
                    accumulated_text.push_str(&text);
                }
                StreamChunk::MessageComplete(response) => {
                    got_message_complete = true;
                    assert!(!response.content.is_empty());
                }
                StreamChunk::Error(e) => panic!("Unexpected error: {e:?}"),
                _ => {}
            }
        }

        assert!(got_text_delta, "Should receive at least one TextDelta");
        assert!(
            got_message_complete,
            "Should receive MessageComplete at the end"
        );
        assert!(
            accumulated_text.to_lowercase().contains("hello"),
            "Response should contain 'hello', got: {accumulated_text}"
        );
    }
}