Skip to main content

steer_core/api/gemini/
client.rs

1use async_trait::async_trait;
2use futures::StreamExt;
3use reqwest::{Client as HttpClient, StatusCode};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use tokio_util::sync::CancellationToken;
7use tracing::{debug, error, info, warn};
8
9use crate::api::error::{ApiError, SseParseError, StreamError};
10use crate::api::provider::{CompletionResponse, CompletionStream, Provider, StreamChunk};
11use crate::api::sse::parse_sse_stream;
12use crate::app::SystemContext;
13use crate::app::conversation::{
14    AssistantContent, ImageSource, Message as AppMessage, ThoughtContent, ThoughtSignature,
15    ToolResult, UserContent,
16};
17use crate::config::model::{ModelId, ModelParameters};
18use steer_tools::ToolSchema;
19
/// Base URL of the Gemini (Generative Language) REST API, pinned to the `v1beta` version.
const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
21
22#[derive(Debug, Deserialize, Serialize, Clone)] // Added Serialize and Clone for potential future use
23struct GeminiBlob {
24    #[serde(rename = "mimeType")]
25    mime_type: String,
26    data: String, // Assuming base64 encoded data
27}
28
29#[derive(Debug, Deserialize, Serialize, Clone)] // Added Serialize and Clone
30struct GeminiFileData {
31    #[serde(rename = "mimeType")]
32    mime_type: String,
33    #[serde(rename = "fileUri")]
34    file_uri: String,
35}
36
37#[expect(dead_code)]
38#[derive(Debug, Deserialize, Serialize, Clone)] // Added Serialize and Clone
39struct GeminiCodeExecutionResult {
40    outcome: String, // e.g., "OK", "ERROR"
41                     // Potentially add output field later if needed
42}
43
44pub struct GeminiClient {
45    api_key: String,
46    client: HttpClient,
47}
48
49impl GeminiClient {
50    pub fn new(api_key: impl Into<String>) -> Self {
51        Self {
52            api_key: api_key.into(),
53            client: HttpClient::new(),
54        }
55    }
56}
57
58#[derive(Debug, Serialize)]
59struct GeminiRequest {
60    contents: Vec<GeminiContent>,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    #[serde(rename = "systemInstruction")]
63    system_instruction: Option<GeminiSystemInstruction>,
64    #[serde(skip_serializing_if = "Option::is_none")]
65    tools: Option<Vec<GeminiTool>>,
66    #[serde(skip_serializing_if = "Option::is_none")]
67    #[serde(rename = "generationConfig")]
68    generation_config: Option<GeminiGenerationConfig>,
69}
70
71#[derive(Debug, Serialize, Default, Clone)]
72struct GeminiGenerationConfig {
73    #[serde(skip_serializing_if = "Option::is_none")]
74    #[serde(rename = "stopSequences")]
75    stop_sequences: Option<Vec<String>>,
76    #[serde(skip_serializing_if = "Option::is_none")]
77    #[serde(rename = "responseMimeType")]
78    response_mime_type: Option<GeminiMimeType>,
79    #[serde(skip_serializing_if = "Option::is_none")]
80    #[serde(rename = "candidateCount")]
81    candidate_count: Option<i32>,
82    #[serde(skip_serializing_if = "Option::is_none")]
83    #[serde(rename = "maxOutputTokens")]
84    max_output_tokens: Option<i32>,
85    #[serde(skip_serializing_if = "Option::is_none")]
86    temperature: Option<f32>,
87    #[serde(skip_serializing_if = "Option::is_none")]
88    #[serde(rename = "topP")]
89    top_p: Option<f32>,
90    #[serde(skip_serializing_if = "Option::is_none")]
91    #[serde(rename = "topK")]
92    top_k: Option<i32>,
93    #[serde(skip_serializing_if = "Option::is_none")]
94    #[serde(rename = "thinkingConfig")]
95    thinking_config: Option<GeminiThinkingConfig>,
96}
97
98#[derive(Debug, Serialize, Default, Clone)]
99struct GeminiThinkingConfig {
100    #[serde(skip_serializing_if = "Option::is_none")]
101    #[serde(rename = "includeThoughts")]
102    include_thoughts: Option<bool>,
103    #[serde(skip_serializing_if = "Option::is_none")]
104    #[serde(rename = "thinkingBudget")]
105    thinking_budget: Option<i32>,
106}
107
108#[expect(dead_code)]
109#[derive(Debug, Serialize, Clone)]
110#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
111enum GeminiMimeType {
112    MimeTypeUnspecified,
113    TextPlain,
114    ApplicationJson,
115}
116
117#[derive(Debug, Serialize)]
118struct GeminiSystemInstruction {
119    parts: Vec<GeminiRequestPart>,
120}
121
122#[derive(Debug, Serialize)]
123struct GeminiContent {
124    role: String,
125    parts: Vec<GeminiRequestPart>,
126}
127
128// Enum for parts used ONLY in requests
129#[derive(Debug, Serialize)]
130#[serde(untagged)]
131enum GeminiRequestPart {
132    Text {
133        text: String,
134    },
135    #[serde(rename = "inlineData")]
136    InlineData {
137        #[serde(rename = "inlineData")]
138        inline_data: GeminiBlob,
139    },
140    #[serde(rename = "fileData")]
141    FileData {
142        #[serde(rename = "fileData")]
143        file_data: GeminiFileData,
144    },
145    #[serde(rename = "functionCall")]
146    FunctionCall {
147        #[serde(rename = "functionCall")]
148        function_call: GeminiFunctionCall, // Used for model turns in history
149        #[serde(rename = "thoughtSignature", skip_serializing_if = "Option::is_none")]
150        thought_signature: Option<String>,
151    },
152    #[serde(rename = "functionResponse")]
153    FunctionResponse {
154        #[serde(rename = "functionResponse")]
155        function_response: GeminiFunctionResponse, // Used for function/tool turns
156    },
157}
158
159// Enum for parts received ONLY in responses
160#[derive(Debug, Deserialize)]
161#[serde(untagged)]
162enum GeminiResponsePartData {
163    Text {
164        text: String,
165    },
166    #[serde(rename = "inlineData")]
167    InlineData {
168        #[serde(rename = "inlineData")]
169        inline_data: GeminiBlob,
170    },
171    #[serde(rename = "functionCall")]
172    FunctionCall {
173        #[serde(rename = "functionCall")]
174        function_call: GeminiFunctionCall,
175    },
176    #[serde(rename = "fileData")]
177    FileData {
178        #[serde(rename = "fileData")]
179        file_data: GeminiFileData,
180    },
181    #[serde(rename = "executableCode")]
182    ExecutableCode {
183        #[serde(rename = "executableCode")]
184        executable_code: GeminiExecutableCode,
185    },
186    // Add other variants back here if needed
187}
188
189// 2. Change GeminiResponsePart to a struct
190#[derive(Debug, Deserialize)]
191struct GeminiResponsePart {
192    #[serde(default)] // Defaults to false if missing
193    thought: bool,
194    #[serde(default, rename = "thoughtSignature", alias = "thought_signature")]
195    thought_signature: Option<String>,
196
197    #[serde(flatten)] // Look for data fields directly in this struct's JSON
198    data: GeminiResponsePartData,
199}
200
201#[derive(Debug, Serialize, Deserialize)]
202struct GeminiFunctionCall {
203    name: String,
204    args: Value,
205}
206
207#[derive(Debug, Serialize, PartialEq)]
208struct GeminiTool {
209    #[serde(rename = "functionDeclarations")]
210    function_declarations: Vec<GeminiFunctionDeclaration>,
211}
212
213#[derive(Debug, Serialize, PartialEq)]
214struct GeminiFunctionDeclaration {
215    name: String,
216    description: String,
217    parameters: GeminiParameterSchema,
218}
219
220#[derive(Debug, Serialize, PartialEq)]
221struct GeminiParameterSchema {
222    #[serde(rename = "type")]
223    schema_type: String, // Typically "object"
224    properties: serde_json::Map<String, Value>,
225    required: Vec<String>,
226}
227
228#[derive(Debug, Deserialize)]
229struct GeminiResponse {
230    #[serde(rename = "candidates")]
231    #[serde(skip_serializing_if = "Option::is_none")]
232    candidates: Option<Vec<GeminiCandidate>>,
233    #[serde(rename = "promptFeedback")]
234    #[serde(skip_serializing_if = "Option::is_none")]
235    prompt_feedback: Option<GeminiPromptFeedback>,
236    #[serde(rename = "usageMetadata")]
237    #[serde(skip_serializing_if = "Option::is_none")]
238    usage_metadata: Option<GeminiUsageMetadata>,
239}
240
241#[derive(Debug, Deserialize)]
242struct GeminiCandidate {
243    content: GeminiContentResponse,
244    #[serde(rename = "finishReason")]
245    #[serde(skip_serializing_if = "Option::is_none")]
246    finish_reason: Option<GeminiFinishReason>,
247    #[serde(rename = "safetyRatings")]
248    #[serde(skip_serializing_if = "Option::is_none")]
249    safety_ratings: Option<Vec<GeminiSafetyRating>>,
250    #[serde(rename = "citationMetadata")]
251    #[serde(skip_serializing_if = "Option::is_none")]
252    citation_metadata: Option<GeminiCitationMetadata>,
253}
254
255#[derive(Debug, Deserialize)]
256#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
257enum GeminiFinishReason {
258    FinishReasonUnspecified,
259    Stop,
260    MaxTokens,
261    Safety,
262    Recitation,
263    Other,
264    #[serde(rename = "TOOL_CODE_ERROR")]
265    ToolCodeError,
266    #[serde(rename = "TOOL_EXECUTION_HALT")]
267    ToolExecutionHalt,
268    MalformedFunctionCall,
269}
270
271#[derive(Debug, Deserialize)]
272struct GeminiPromptFeedback {
273    #[serde(rename = "blockReason")]
274    #[serde(skip_serializing_if = "Option::is_none")]
275    block_reason: Option<GeminiBlockReason>,
276    #[serde(rename = "safetyRatings")]
277    #[serde(skip_serializing_if = "Option::is_none")]
278    safety_ratings: Option<Vec<GeminiSafetyRating>>,
279}
280
281#[derive(Debug, Deserialize)]
282#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
283enum GeminiBlockReason {
284    BlockReasonUnspecified,
285    Safety,
286    Other,
287}
288
289#[derive(Debug, Deserialize)]
290#[expect(dead_code)]
291struct GeminiSafetyRating {
292    category: GeminiHarmCategory,
293    probability: GeminiHarmProbability,
294    #[serde(default)] // Default to false if missing
295    blocked: bool,
296}
297
298#[derive(Debug, Deserialize, Serialize)] // Add Serialize for potential use in SafetySetting
299#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
300#[expect(clippy::enum_variant_names)]
301enum GeminiHarmCategory {
302    HarmCategoryUnspecified,
303    HarmCategoryDerogatory,
304    HarmCategoryToxicity,
305    HarmCategoryViolence,
306    HarmCategorySexual,
307    HarmCategoryMedical,
308    HarmCategoryDangerous,
309    HarmCategoryHarassment,
310    HarmCategoryHateSpeech,
311    HarmCategorySexuallyExplicit,
312    HarmCategoryDangerousContent,
313    HarmCategoryCivicIntegrity,
314}
315
316#[derive(Debug, Deserialize)]
317#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
318enum GeminiHarmProbability {
319    HarmProbabilityUnspecified,
320    Negligible,
321    Low,
322    Medium,
323    High,
324}
325
326#[expect(dead_code)]
327#[derive(Debug, Deserialize)]
328struct GeminiCitationMetadata {
329    #[serde(rename = "citationSources")]
330    #[serde(skip_serializing_if = "Option::is_none")]
331    citation_sources: Option<Vec<GeminiCitationSource>>,
332}
333
334#[expect(dead_code)]
335#[derive(Debug, Deserialize)]
336struct GeminiCitationSource {
337    #[serde(rename = "startIndex")]
338    #[serde(skip_serializing_if = "Option::is_none")]
339    start_index: Option<i32>,
340    #[serde(rename = "endIndex")]
341    #[serde(skip_serializing_if = "Option::is_none")]
342    end_index: Option<i32>,
343    #[serde(skip_serializing_if = "Option::is_none")]
344    uri: Option<String>,
345    #[serde(skip_serializing_if = "Option::is_none")]
346    license: Option<String>,
347}
348
349#[derive(Debug, Deserialize)]
350struct GeminiUsageMetadata {
351    #[serde(rename = "promptTokenCount")]
352    #[serde(skip_serializing_if = "Option::is_none")]
353    prompt: Option<i32>,
354    #[serde(rename = "candidatesTokenCount")]
355    #[serde(skip_serializing_if = "Option::is_none")]
356    candidates: Option<i32>,
357    #[serde(rename = "totalTokenCount")]
358    #[serde(skip_serializing_if = "Option::is_none")]
359    total: Option<i32>,
360}
361
362#[derive(Debug, Serialize, Deserialize)]
363struct GeminiFunctionResponse {
364    name: String,
365    response: GeminiResponseContent,
366}
367
368#[derive(Debug, Serialize, Deserialize)]
369struct GeminiResponseContent {
370    content: Value,
371}
372
373#[derive(Debug, Serialize, Deserialize)]
374struct GeminiExecutableCode {
375    language: String, // e.g., PYTHON
376    code: String,
377}
378
379#[derive(Debug, Deserialize)]
380#[expect(dead_code)]
381struct GeminiContentResponse {
382    role: String,
383    parts: Vec<GeminiResponsePart>,
384}
385
386fn user_image_to_request_part(
387    image: &crate::app::conversation::ImageContent,
388) -> Result<GeminiRequestPart, ApiError> {
389    match &image.source {
390        ImageSource::DataUrl { data_url } => {
391            let Some((meta, payload)) = data_url.split_once(',') else {
392                return Err(ApiError::InvalidRequest {
393                    provider: "google".to_string(),
394                    details: "Image data URL is missing ',' separator".to_string(),
395                });
396            };
397
398            let Some(meta_body) = meta.strip_prefix("data:") else {
399                return Err(ApiError::InvalidRequest {
400                    provider: "google".to_string(),
401                    details: "Image data URL must start with 'data:'".to_string(),
402                });
403            };
404
405            if !meta_body
406                .split(';')
407                .any(|segment| segment.eq_ignore_ascii_case("base64"))
408            {
409                return Err(ApiError::InvalidRequest {
410                    provider: "google".to_string(),
411                    details: "Gemini image data URLs must be base64 encoded".to_string(),
412                });
413            }
414
415            let mime_type = meta_body
416                .split(';')
417                .next()
418                .filter(|value| !value.trim().is_empty())
419                .unwrap_or("application/octet-stream")
420                .to_string();
421
422            if !mime_type.starts_with("image/") {
423                return Err(ApiError::UnsupportedFeature {
424                    provider: "google".to_string(),
425                    feature: "image input mime type".to_string(),
426                    details: format!("Unsupported image media type '{}'", mime_type),
427                });
428            }
429
430            Ok(GeminiRequestPart::InlineData {
431                inline_data: GeminiBlob {
432                    mime_type,
433                    data: payload.to_string(),
434                },
435            })
436        }
437        ImageSource::Url { url } => Ok(GeminiRequestPart::FileData {
438            file_data: GeminiFileData {
439                mime_type: image.mime_type.clone(),
440                file_uri: url.clone(),
441            },
442        }),
443        ImageSource::SessionFile { relative_path } => Err(ApiError::UnsupportedFeature {
444            provider: "google".to_string(),
445            feature: "image input source".to_string(),
446            details: format!(
447                "Gemini adapter cannot access session file '{}' directly; use data URLs or URL sources",
448                relative_path
449            ),
450        }),
451    }
452}
453
454fn convert_messages(messages: Vec<AppMessage>) -> Result<Vec<GeminiContent>, ApiError> {
455    let mut converted = Vec::new();
456
457    for msg in messages {
458        match &msg.data {
459            crate::app::conversation::MessageData::User { content, .. } => {
460                let mut parts = Vec::new();
461                for user_content in content {
462                    match user_content {
463                        UserContent::Text { text } => {
464                            parts.push(GeminiRequestPart::Text { text: text.clone() });
465                        }
466                        UserContent::Image { image } => {
467                            parts.push(user_image_to_request_part(image)?);
468                        }
469                        UserContent::CommandExecution {
470                            command,
471                            stdout,
472                            stderr,
473                            exit_code,
474                        } => {
475                            parts.push(GeminiRequestPart::Text {
476                                text: UserContent::format_command_execution_as_xml(
477                                    command, stdout, stderr, *exit_code,
478                                ),
479                            });
480                        }
481                    }
482                }
483
484                if !parts.is_empty() {
485                    converted.push(GeminiContent {
486                        role: "user".to_string(),
487                        parts,
488                    });
489                }
490            }
491            crate::app::conversation::MessageData::Assistant { content, .. } => {
492                let mut parts = Vec::new();
493                for assistant_content in content {
494                    match assistant_content {
495                        AssistantContent::Text { text } => {
496                            parts.push(GeminiRequestPart::Text { text: text.clone() });
497                        }
498                        AssistantContent::Image { image } => {
499                            match user_image_to_request_part(image) {
500                                Ok(part) => parts.push(part),
501                                Err(err) => {
502                                    debug!(
503                                        target: "gemini::convert_messages",
504                                        "Skipping unsupported assistant image block: {}",
505                                        err
506                                    );
507                                }
508                            }
509                        }
510                        AssistantContent::ToolCall {
511                            tool_call,
512                            thought_signature,
513                        } => parts.push(GeminiRequestPart::FunctionCall {
514                            function_call: GeminiFunctionCall {
515                                name: tool_call.name.clone(),
516                                args: tool_call.parameters.clone(),
517                            },
518                            thought_signature: thought_signature
519                                .as_ref()
520                                .map(|signature| signature.as_str().to_string()),
521                        }),
522                        AssistantContent::Thought { .. } => {
523                            // Gemini doesn't send thought blocks in requests
524                        }
525                    }
526                }
527
528                converted.push(GeminiContent {
529                    role: "model".to_string(),
530                    parts,
531                });
532            }
533            crate::app::conversation::MessageData::Tool {
534                tool_use_id,
535                result,
536                ..
537            } => {
538                let result_value = match result {
539                    ToolResult::Error(e) => Value::String(format!("Error: {e}")),
540                    _ => serde_json::to_value(result)
541                        .unwrap_or_else(|_| Value::String(result.llm_format())),
542                };
543
544                let parts = vec![GeminiRequestPart::FunctionResponse {
545                    function_response: GeminiFunctionResponse {
546                        name: tool_use_id.clone(),
547                        response: GeminiResponseContent {
548                            content: result_value,
549                        },
550                    },
551                }];
552
553                converted.push(GeminiContent {
554                    role: "function".to_string(),
555                    parts,
556                });
557            }
558        }
559    }
560
561    Ok(converted)
562}
563
564fn resolve_ref<'a>(root: &'a Value, schema: &'a Value) -> Option<&'a Value> {
565    let reference = schema.get("$ref").and_then(|v| v.as_str())?;
566    let path = reference.strip_prefix("#/")?;
567    let mut current = root;
568    for segment in path.split('/') {
569        current = current.get(segment)?;
570    }
571    Some(current)
572}
573
574fn infer_type_from_enum(values: &[Value]) -> Option<String> {
575    let mut has_string = false;
576    let mut has_number = false;
577    let mut has_bool = false;
578    let mut has_object = false;
579    let mut has_array = false;
580
581    for value in values {
582        match value {
583            Value::String(_) => has_string = true,
584            Value::Number(_) => has_number = true,
585            Value::Bool(_) => has_bool = true,
586            Value::Object(_) => has_object = true,
587            Value::Array(_) => has_array = true,
588            Value::Null => {}
589        }
590    }
591
592    let kind_count = u8::from(has_string)
593        + u8::from(has_number)
594        + u8::from(has_bool)
595        + u8::from(has_object)
596        + u8::from(has_array);
597
598    if kind_count != 1 {
599        return None;
600    }
601
602    if has_string {
603        Some("string".to_string())
604    } else if has_number {
605        Some("number".to_string())
606    } else if has_bool {
607        Some("boolean".to_string())
608    } else if has_object {
609        Some("object".to_string())
610    } else if has_array {
611        Some("array".to_string())
612    } else {
613        None
614    }
615}
616
617fn normalize_type(value: &Value) -> Value {
618    if let Some(type_str) = value.as_str() {
619        return Value::String(type_str.to_string());
620    }
621
622    if let Some(type_array) = value.as_array()
623        && let Some(primary_type) = type_array
624            .iter()
625            .find_map(|v| if v.is_null() { None } else { v.as_str() })
626    {
627        return Value::String(primary_type.to_string());
628    }
629
630    Value::String("string".to_string())
631}
632
633fn extract_enum_values(value: &Value) -> Vec<Value> {
634    let Some(obj) = value.as_object() else {
635        return Vec::new();
636    };
637
638    if let Some(enum_values) = obj.get("enum").and_then(|v| v.as_array()) {
639        return enum_values
640            .iter()
641            .filter(|v| !v.is_null())
642            .cloned()
643            .collect();
644    }
645
646    if let Some(const_value) = obj.get("const") {
647        if const_value.is_null() {
648            return Vec::new();
649        }
650        return vec![const_value.clone()];
651    }
652
653    Vec::new()
654}
655
656fn merge_property(properties: &mut serde_json::Map<String, Value>, key: &str, value: &Value) {
657    match properties.get_mut(key) {
658        None => {
659            properties.insert(key.to_string(), value.clone());
660        }
661        Some(existing) => {
662            if existing == value {
663                return;
664            }
665
666            let existing_values = extract_enum_values(existing);
667            let incoming_values = extract_enum_values(value);
668            if incoming_values.is_empty() && existing_values.is_empty() {
669                return;
670            }
671
672            let mut combined = existing_values;
673            for item in incoming_values {
674                if !combined.contains(&item) {
675                    combined.push(item);
676                }
677            }
678
679            if combined.is_empty() {
680                return;
681            }
682
683            if let Some(obj) = existing.as_object_mut() {
684                obj.remove("const");
685                obj.insert("enum".to_string(), Value::Array(combined.clone()));
686                if !obj.contains_key("type")
687                    && let Some(inferred) = infer_type_from_enum(&combined)
688                {
689                    obj.insert("type".to_string(), Value::String(inferred));
690                }
691            }
692        }
693    }
694}
695
696fn merge_union_schemas(root: &Value, variants: &[Value]) -> Value {
697    let mut merged_props = serde_json::Map::new();
698    let mut required_intersection: Option<std::collections::BTreeSet<String>> = None;
699    let mut enum_values: Vec<Value> = Vec::new();
700    let mut type_candidates: Vec<String> = Vec::new();
701
702    for variant in variants {
703        let sanitized = sanitize_for_gemini(root, variant);
704
705        if let Some(schema_type) = sanitized.get("type").and_then(|v| v.as_str()) {
706            type_candidates.push(schema_type.to_string());
707        }
708
709        if let Some(props) = sanitized.get("properties").and_then(|v| v.as_object()) {
710            for (key, value) in props {
711                merge_property(&mut merged_props, key, value);
712            }
713        }
714
715        if let Some(req) = sanitized.get("required").and_then(|v| v.as_array()) {
716            let req_set: std::collections::BTreeSet<String> = req
717                .iter()
718                .filter_map(|item| item.as_str().map(|s| s.to_string()))
719                .collect();
720
721            required_intersection = match required_intersection.take() {
722                None => Some(req_set),
723                Some(existing) => Some(
724                    existing
725                        .intersection(&req_set)
726                        .cloned()
727                        .collect::<std::collections::BTreeSet<String>>(),
728                ),
729            };
730        }
731
732        if let Some(values) = sanitized.get("enum").and_then(|v| v.as_array()) {
733            for value in values {
734                if value.is_null() {
735                    continue;
736                }
737                if !enum_values.contains(value) {
738                    enum_values.push(value.clone());
739                }
740            }
741        }
742    }
743
744    let schema_type = if !merged_props.is_empty() {
745        "object".to_string()
746    } else if let Some(inferred) = infer_type_from_enum(&enum_values) {
747        inferred
748    } else if let Some(first) = type_candidates.first() {
749        first.clone()
750    } else {
751        "string".to_string()
752    };
753
754    let mut merged = serde_json::Map::new();
755    merged.insert("type".to_string(), Value::String(schema_type));
756
757    if !merged_props.is_empty() {
758        merged.insert("properties".to_string(), Value::Object(merged_props));
759    }
760
761    if let Some(required_set) = required_intersection
762        && !required_set.is_empty()
763    {
764        merged.insert(
765            "required".to_string(),
766            Value::Array(
767                required_set
768                    .into_iter()
769                    .map(Value::String)
770                    .collect::<Vec<_>>(),
771            ),
772        );
773    }
774
775    if !enum_values.is_empty() {
776        merged.insert("enum".to_string(), Value::Array(enum_values));
777    }
778
779    Value::Object(merged)
780}
781
782fn sanitize_for_gemini(root: &Value, schema: &Value) -> Value {
783    if let Some(resolved) = resolve_ref(root, schema) {
784        return sanitize_for_gemini(root, resolved);
785    }
786
787    let Some(obj) = schema.as_object() else {
788        return schema.clone();
789    };
790
791    if let Some(union) = obj
792        .get("oneOf")
793        .or_else(|| obj.get("anyOf"))
794        .or_else(|| obj.get("allOf"))
795        .and_then(|v| v.as_array())
796    {
797        return merge_union_schemas(root, union);
798    }
799
800    let mut out = serde_json::Map::new();
801    for (key, value) in obj {
802        match key.as_str() {
803            "$ref"
804            | "$defs"
805            | "oneOf"
806            | "anyOf"
807            | "allOf"
808            | "const"
809            | "additionalProperties"
810            | "default"
811            | "examples"
812            | "title"
813            | "pattern"
814            | "minLength"
815            | "maxLength"
816            | "minimum"
817            | "maximum"
818            | "minItems"
819            | "maxItems"
820            | "uniqueItems"
821            | "deprecated" => {}
822            "type" => {
823                out.insert("type".to_string(), normalize_type(value));
824            }
825            "properties" => {
826                if let Some(props) = value.as_object() {
827                    let mut sanitized_props = serde_json::Map::new();
828                    for (prop_key, prop_value) in props {
829                        sanitized_props
830                            .insert(prop_key.clone(), sanitize_for_gemini(root, prop_value));
831                    }
832                    out.insert("properties".to_string(), Value::Object(sanitized_props));
833                }
834            }
835            "items" => {
836                out.insert("items".to_string(), sanitize_for_gemini(root, value));
837            }
838            "enum" => {
839                let values = value
840                    .as_array()
841                    .map(|items| {
842                        items
843                            .iter()
844                            .filter(|v| !v.is_null())
845                            .cloned()
846                            .collect::<Vec<_>>()
847                    })
848                    .unwrap_or_default();
849                out.insert("enum".to_string(), Value::Array(values));
850            }
851            _ => {
852                out.insert(key.clone(), sanitize_for_gemini(root, value));
853            }
854        }
855    }
856
857    if let Some(const_value) = obj.get("const")
858        && !const_value.is_null()
859    {
860        out.insert("enum".to_string(), Value::Array(vec![const_value.clone()]));
861        if !out.contains_key("type")
862            && let Some(inferred) = infer_type_from_enum(std::slice::from_ref(const_value))
863        {
864            out.insert("type".to_string(), Value::String(inferred));
865        }
866    }
867
868    if out.get("type") == Some(&Value::String("object".to_string()))
869        && !out.contains_key("properties")
870    {
871        out.insert(
872            "properties".to_string(),
873            Value::Object(serde_json::Map::new()),
874        );
875    }
876
877    if !out.contains_key("type") {
878        if out.contains_key("properties") {
879            out.insert("type".to_string(), Value::String("object".to_string()));
880        } else if let Some(enum_values) = out.get("enum").and_then(|v| v.as_array())
881            && let Some(inferred) = infer_type_from_enum(enum_values)
882        {
883            out.insert("type".to_string(), Value::String(inferred));
884        }
885    }
886
887    Value::Object(out)
888}
889
890fn simplify_property_schema(key: &str, tool_name: &str, property_value: &Value) -> Value {
891    if let Some(prop_map_orig) = property_value.as_object() {
892        let mut simplified_prop = prop_map_orig.clone();
893
894        // Remove 'additionalProperties' as Gemini doesn't support it
895        if simplified_prop.remove("additionalProperties").is_some() {
896            debug!(target: "gemini::simplify_property_schema", "Removed 'additionalProperties' from property '{}' in tool '{}'", key, tool_name);
897        }
898
899        // Simplify 'type' field (handle arrays like ["string", "null"])
900        if let Some(type_val) = simplified_prop.get_mut("type") {
901            if let Some(type_array) = type_val.as_array() {
902                if let Some(primary_type) = type_array
903                    .iter()
904                    .find_map(|v| if v.is_null() { None } else { v.as_str() })
905                {
906                    *type_val = serde_json::Value::String(primary_type.to_string());
907                } else {
908                    warn!(target: "gemini::simplify_property_schema", "Could not determine primary type for property '{}' in tool '{}', defaulting to string.", key, tool_name);
909                    *type_val = serde_json::Value::String("string".to_string());
910                }
911            } else if !type_val.is_string() {
912                warn!(target: "gemini::simplify_property_schema", "Unexpected 'type' format for property '{}' in tool '{}': {:?}. Defaulting to string.", key, tool_name, type_val);
913                *type_val = serde_json::Value::String("string".to_string());
914            }
915            // If it's already a simple string, do nothing.
916        }
917
918        // Fix integer format if necessary
919        if simplified_prop.get("type") == Some(&serde_json::Value::String("integer".to_string()))
920            && let Some(format_val) = simplified_prop.get_mut("format")
921            && format_val.as_str() == Some("uint64")
922        {
923            *format_val = serde_json::Value::String("int64".to_string());
924            // Optionally remove minimum if Gemini doesn't support it with int64
925            // simplified_prop.remove("minimum");
926        }
927
928        // For string types, Gemini only supports 'enum' and 'date-time' formats
929        if simplified_prop.get("type") == Some(&serde_json::Value::String("string".to_string())) {
930            let should_remove_format = simplified_prop
931                .get("format")
932                .and_then(|f| f.as_str())
933                .is_some_and(|format_str| format_str != "enum" && format_str != "date-time");
934
935            if should_remove_format
936                && let Some(format_val) = simplified_prop.remove("format")
937                && let Some(format_str) = format_val.as_str()
938            {
939                debug!(target: "gemini::simplify_property_schema", "Removed unsupported format '{}' from string property '{}' in tool '{}'", format_str, key, tool_name);
940            }
941
942            // Also remove other string validation fields that might not be supported
943            if simplified_prop.remove("minLength").is_some() {
944                debug!(target: "gemini::simplify_property_schema", "Removed 'minLength' from string property '{}' in tool '{}'", key, tool_name);
945            }
946            if simplified_prop.remove("maxLength").is_some() {
947                debug!(target: "gemini::simplify_property_schema", "Removed 'maxLength' from string property '{}' in tool '{}'", key, tool_name);
948            }
949            if simplified_prop.remove("pattern").is_some() {
950                debug!(target: "gemini::simplify_property_schema", "Removed 'pattern' from string property '{}' in tool '{}'", key, tool_name);
951            }
952        }
953
954        // Recursively simplify 'items' if this is an array type
955        if simplified_prop.get("type") == Some(&serde_json::Value::String("array".to_string()))
956            && let Some(items_val) = simplified_prop.get_mut("items")
957        {
958            *items_val = simplify_property_schema(&format!("{key}.items"), tool_name, items_val);
959        }
960
961        // Recursively simplify nested 'properties' if this is an object type
962        if simplified_prop.get("type") == Some(&serde_json::Value::String("object".to_string()))
963            && let Some(Value::Object(props)) = simplified_prop.get_mut("properties")
964        {
965            let simplified_nested_props: serde_json::Map<String, Value> = props
966                .iter()
967                .map(|(nested_key, nested_value)| {
968                    (
969                        nested_key.clone(),
970                        simplify_property_schema(
971                            &format!("{key}.{nested_key}"),
972                            tool_name,
973                            nested_value,
974                        ),
975                    )
976                })
977                .collect();
978            *props = simplified_nested_props;
979        }
980
981        serde_json::Value::Object(simplified_prop)
982    } else {
983        warn!(target: "gemini::simplify_property_schema", "Property value for '{}' in tool '{}' is not an object: {:?}. Using original value.", key, tool_name, property_value);
984        property_value.clone() // Return original if not an object
985    }
986}
987
988fn convert_tools(tools: Vec<ToolSchema>) -> Vec<GeminiTool> {
989    let function_declarations = tools
990        .into_iter()
991        .map(|tool| {
992            let root_schema = tool.input_schema.as_value();
993            let summary = tool.input_schema.summary();
994            let schema_type = if summary.schema_type.is_empty() {
995                "object".to_string()
996            } else {
997                summary.schema_type
998            };
999
1000            // Simplify properties schema for Gemini using the helper function
1001            let simplified_properties = summary
1002                .properties
1003                .iter()
1004                .map(|(key, value)| {
1005                    let sanitized = sanitize_for_gemini(root_schema, value);
1006                    (
1007                        key.clone(),
1008                        simplify_property_schema(key, &tool.name, &sanitized),
1009                    )
1010                })
1011                .collect();
1012
1013            // Construct the parameters object using the specific struct
1014            let parameters = GeminiParameterSchema {
1015                schema_type,                       // Use schema_type field (usually "object")
1016                properties: simplified_properties, // Use simplified properties
1017                required: summary.required,        // Use required field
1018            };
1019
1020            GeminiFunctionDeclaration {
1021                name: tool.name,
1022                description: tool.description,
1023                parameters,
1024            }
1025        })
1026        .collect();
1027
1028    vec![GeminiTool {
1029        function_declarations,
1030    }]
1031}
1032
1033fn convert_response(response: GeminiResponse) -> Result<CompletionResponse, ApiError> {
1034    // Log prompt feedback if present
1035    if let Some(feedback) = &response.prompt_feedback
1036        && let Some(reason) = &feedback.block_reason
1037    {
1038        let details = format!(
1039            "Prompt blocked due to {:?}. Safety ratings: {:?}",
1040            reason, feedback.safety_ratings
1041        );
1042        warn!(target: "gemini::convert_response", "{}", details);
1043        // Return the specific RequestBlocked error
1044        return Err(ApiError::RequestBlocked {
1045            provider: "google".to_string(), // Assuming "google" is the provider name
1046            details,
1047        });
1048    }
1049
1050    // Check candidates *after* checking for prompt blocking
1051    let candidates = if let Some(cands) = response.candidates {
1052        if cands.is_empty() {
1053            // If it was blocked, the previous check should have caught it.
1054            // So, this means no candidates were generated for other reasons.
1055            warn!(target: "gemini::convert_response", "No candidates received, and prompt was not blocked.");
1056            // Use NoChoices error here
1057            return Err(ApiError::NoChoices {
1058                provider: "google".to_string(),
1059            });
1060        }
1061        cands // Return the non-empty vector
1062    } else {
1063        warn!(target: "gemini::convert_response", "No candidates field in Gemini response.");
1064        // Use NoChoices error here as well
1065        return Err(ApiError::NoChoices {
1066            provider: "google".to_string(),
1067        });
1068    };
1069
1070    // For simplicity, still taking the first candidate. Multi-candidate handling could be added.
1071    // Access candidates safely since we've checked it's not None or empty.
1072    let candidate = &candidates[0];
1073
1074    // Log finish reason and safety ratings if present
1075    if let Some(reason) = &candidate.finish_reason {
1076        match reason {
1077            GeminiFinishReason::Stop => { /* Normal completion */ }
1078            GeminiFinishReason::MaxTokens => {
1079                warn!(target: "gemini::convert_response", "Response stopped due to MaxTokens limit.");
1080            }
1081            GeminiFinishReason::Safety => {
1082                warn!(target: "gemini::convert_response", "Response stopped due to safety settings. Ratings: {:?}", candidate.safety_ratings);
1083                // Consider returning an error or modifying the response based on safety ratings
1084            }
1085            GeminiFinishReason::Recitation => {
1086                warn!(target: "gemini::convert_response", "Response stopped due to potential recitation. Citations: {:?}", candidate.citation_metadata);
1087            }
1088            GeminiFinishReason::MalformedFunctionCall => {
1089                warn!(target: "gemini::convert_response", "Response stopped due to malformed function call.");
1090            }
1091            _ => {
1092                info!(target: "gemini::convert_response", "Response finished with reason: {:?}", reason);
1093            }
1094        }
1095    }
1096
1097    // Log usage metadata if present
1098    if let Some(usage) = &response.usage_metadata {
1099        debug!(target: "gemini::convert_response", "Usage - Prompt Tokens: {:?}, Candidates Tokens: {:?}, Total Tokens: {:?}",
1100               usage.prompt, usage.candidates, usage.total);
1101    }
1102
1103    let content: Vec<AssistantContent> = candidate
1104        .content // GeminiContentResponse
1105        .parts   // Vec<GeminiResponsePart> (struct)
1106        .iter()
1107        .filter_map(|part| { // part is &GeminiResponsePart (struct)
1108            // Check if this is a thought part first
1109            if part.thought {
1110                debug!(target: "gemini::convert_response", "Received thought part: {:?}", part);
1111                // For thought parts, extract text content and create a Thought block
1112                if let GeminiResponsePartData::Text { text } = &part.data {
1113                    Some(AssistantContent::Thought {
1114                        thought: ThoughtContent::Simple {
1115                            text: text.clone(),
1116                        },
1117                    })
1118                } else {
1119                    warn!(target: "gemini::convert_response", "Thought part contains non-text data: {:?}", part.data);
1120                    None
1121                }
1122            } else {
1123                // Regular (non-thought) content processing
1124                match &part.data {
1125                    GeminiResponsePartData::Text { text } => Some(AssistantContent::Text {
1126                        text: text.clone(),
1127                    }),
1128                    GeminiResponsePartData::InlineData { inline_data } => {
1129                        warn!(target: "gemini::convert_response", "Received InlineData part (MIME type: {}). Converting to placeholder text.", inline_data.mime_type);
1130                        Some(AssistantContent::Text { text: format!("[Inline Data: {}]", inline_data.mime_type) })
1131                    }
1132                    GeminiResponsePartData::FunctionCall { function_call } => {
1133                        Some(AssistantContent::ToolCall {
1134                            tool_call: steer_tools::ToolCall {
1135                                id: uuid::Uuid::new_v4().to_string(), // Generate a synthetic ID
1136                                name: function_call.name.clone(),
1137                                parameters: function_call.args.clone(),
1138                            },
1139                            thought_signature: part
1140                                .thought_signature
1141                                .clone()
1142                                .map(ThoughtSignature::new),
1143                        })
1144                    }
1145                    GeminiResponsePartData::FileData { file_data } => {
1146                        warn!(target: "gemini::convert_response", "Received FileData part (URI: {}). Converting to placeholder text.", file_data.file_uri);
1147                        Some(AssistantContent::Text { text: format!("[File Data: {}]", file_data.file_uri) })
1148                    }
1149                     GeminiResponsePartData::ExecutableCode { executable_code } => {
1150                         info!(target: "gemini::convert_response", "Received ExecutableCode part ({}). Converting to text.",
1151                              executable_code.language);
1152                         Some(AssistantContent::Text {
1153                             text: format!(
1154                                 "```{}
1155{}
1156```",
1157                                 executable_code.language.to_lowercase(),
1158                                 executable_code.code
1159                             ),
1160                         })
1161                     }
1162                }
1163            }
1164        })
1165        .collect();
1166
1167    Ok(CompletionResponse { content })
1168}
1169
1170#[async_trait]
1171impl Provider for GeminiClient {
1172    fn name(&self) -> &'static str {
1173        "google"
1174    }
1175
1176    async fn complete(
1177        &self,
1178        model_id: &ModelId,
1179        messages: Vec<AppMessage>,
1180        system: Option<SystemContext>,
1181        tools: Option<Vec<ToolSchema>>,
1182        _call_options: Option<ModelParameters>,
1183        token: CancellationToken,
1184    ) -> Result<CompletionResponse, ApiError> {
1185        let model_name = &model_id.id; // Use the model ID string
1186        let url = format!(
1187            "{}/models/{}:generateContent?key={}",
1188            GEMINI_API_BASE, model_name, self.api_key
1189        );
1190
1191        let gemini_contents = convert_messages(messages)?;
1192
1193        let system_instruction = system
1194            .and_then(|context| context.render())
1195            .map(|instructions| GeminiSystemInstruction {
1196                parts: vec![GeminiRequestPart::Text { text: instructions }],
1197            });
1198
1199        let gemini_tools = tools.map(convert_tools);
1200
1201        // Derive generation config from call options, respecting model/catalog settings
1202        let (temperature, top_p, max_output_tokens) = {
1203            let opts = _call_options.as_ref();
1204            (
1205                opts.and_then(|o| o.temperature).or(Some(1.0)),
1206                opts.and_then(|o| o.top_p).or(Some(0.95)),
1207                opts.and_then(|o| o.max_tokens)
1208                    .map(|v| v as i32)
1209                    .or(Some(65536)),
1210            )
1211        };
1212        let thinking_config = _call_options
1213            .as_ref()
1214            .and_then(|o| o.thinking_config)
1215            .and_then(|tc| {
1216                if tc.enabled {
1217                    Some(GeminiThinkingConfig {
1218                        include_thoughts: tc.include_thoughts,
1219                        thinking_budget: tc.budget_tokens.map(|v| v as i32),
1220                    })
1221                } else {
1222                    None
1223                }
1224            });
1225
1226        let request = GeminiRequest {
1227            contents: gemini_contents,
1228            system_instruction,
1229            tools: gemini_tools,
1230            generation_config: Some(GeminiGenerationConfig {
1231                max_output_tokens,
1232                temperature,
1233                top_p,
1234                thinking_config,
1235                ..Default::default()
1236            }),
1237        };
1238
1239        let response = tokio::select! {
1240            biased;
1241            () = token.cancelled() => {
1242                debug!(target: "gemini::complete", "Cancellation token triggered before sending request.");
1243                return Err(ApiError::Cancelled{ provider: self.name().to_string()});
1244            }
1245            res = self.client.post(&url).json(&request).send() => {
1246                res.map_err(ApiError::Network)?
1247            }
1248        };
1249        let status = response.status();
1250
1251        if status != StatusCode::OK {
1252            let error_text = response.text().await.map_err(ApiError::Network)?;
1253            error!(target: "Gemini API Error Response", "Status: {}, Body: {}", status, error_text);
1254            return Err(match status.as_u16() {
1255                401 | 403 => ApiError::AuthenticationFailed {
1256                    provider: self.name().to_string(),
1257                    details: error_text,
1258                },
1259                429 => ApiError::RateLimited {
1260                    provider: self.name().to_string(),
1261                    details: error_text,
1262                },
1263                400 | 404 => {
1264                    error!(target: "Gemini API Error Response", "Status: {}, Body: {}, Request: {}", status, error_text, serde_json::to_string_pretty(&request).unwrap_or_else(|_| "Failed to serialize request".to_string()));
1265                    ApiError::InvalidRequest {
1266                        provider: self.name().to_string(),
1267                        details: error_text,
1268                    }
1269                } // 404 might mean invalid model
1270                500..=599 => ApiError::ServerError {
1271                    provider: self.name().to_string(),
1272                    status_code: status.as_u16(),
1273                    details: error_text,
1274                },
1275                _ => ApiError::Unknown {
1276                    provider: self.name().to_string(),
1277                    details: error_text,
1278                },
1279            });
1280        }
1281
1282        let response_text = response.text().await.map_err(ApiError::Network)?;
1283
1284        match serde_json::from_str::<GeminiResponse>(&response_text) {
1285            Ok(gemini_response) => {
1286                convert_response(gemini_response).map_err(|e| ApiError::ResponseParsingError {
1287                    provider: self.name().to_string(),
1288                    details: e.to_string(),
1289                })
1290            }
1291            Err(e) => {
1292                error!(target: "Gemini API JSON Parsing Error", "Failed to parse JSON: {}. Response body:\n{}", e, response_text);
1293                Err(ApiError::ResponseParsingError {
1294                    provider: self.name().to_string(),
1295                    details: format!("Status: {status}, Error: {e}, Body: {response_text}"),
1296                })
1297            }
1298        }
1299    }
1300
1301    async fn stream_complete(
1302        &self,
1303        model_id: &ModelId,
1304        messages: Vec<AppMessage>,
1305        system: Option<SystemContext>,
1306        tools: Option<Vec<ToolSchema>>,
1307        _call_options: Option<ModelParameters>,
1308        token: CancellationToken,
1309    ) -> Result<CompletionStream, ApiError> {
1310        let model_name = &model_id.id;
1311        let url = format!(
1312            "{}/models/{}:streamGenerateContent?alt=sse&key={}",
1313            GEMINI_API_BASE, model_name, self.api_key
1314        );
1315
1316        let gemini_contents = convert_messages(messages)?;
1317
1318        let system_instruction = system
1319            .and_then(|context| context.render())
1320            .map(|instructions| GeminiSystemInstruction {
1321                parts: vec![GeminiRequestPart::Text { text: instructions }],
1322            });
1323
1324        let gemini_tools = tools.map(convert_tools);
1325
1326        let (temperature, top_p, max_output_tokens) = {
1327            let opts = _call_options.as_ref();
1328            (
1329                opts.and_then(|o| o.temperature).or(Some(1.0)),
1330                opts.and_then(|o| o.top_p).or(Some(0.95)),
1331                opts.and_then(|o| o.max_tokens)
1332                    .map(|v| v as i32)
1333                    .or(Some(65536)),
1334            )
1335        };
1336        let thinking_config = _call_options
1337            .as_ref()
1338            .and_then(|o| o.thinking_config)
1339            .and_then(|tc| {
1340                if tc.enabled {
1341                    Some(GeminiThinkingConfig {
1342                        include_thoughts: tc.include_thoughts,
1343                        thinking_budget: tc.budget_tokens.map(|v| v as i32),
1344                    })
1345                } else {
1346                    None
1347                }
1348            });
1349
1350        let request = GeminiRequest {
1351            contents: gemini_contents,
1352            system_instruction,
1353            tools: gemini_tools,
1354            generation_config: Some(GeminiGenerationConfig {
1355                max_output_tokens,
1356                temperature,
1357                top_p,
1358                thinking_config,
1359                ..Default::default()
1360            }),
1361        };
1362
1363        let response = tokio::select! {
1364            biased;
1365            () = token.cancelled() => {
1366                return Err(ApiError::Cancelled{ provider: self.name().to_string()});
1367            }
1368            res = self.client.post(&url).json(&request).send() => {
1369                res.map_err(ApiError::Network)?
1370            }
1371        };
1372
1373        let status = response.status();
1374        if status != StatusCode::OK {
1375            let error_text = response.text().await.map_err(ApiError::Network)?;
1376            error!(target: "gemini::stream", "API error - Status: {}, Body: {}", status, error_text);
1377            return Err(match status.as_u16() {
1378                401 | 403 => ApiError::AuthenticationFailed {
1379                    provider: self.name().to_string(),
1380                    details: error_text,
1381                },
1382                429 => ApiError::RateLimited {
1383                    provider: self.name().to_string(),
1384                    details: error_text,
1385                },
1386                400 | 404 => ApiError::InvalidRequest {
1387                    provider: self.name().to_string(),
1388                    details: error_text,
1389                },
1390                500..=599 => ApiError::ServerError {
1391                    provider: self.name().to_string(),
1392                    status_code: status.as_u16(),
1393                    details: error_text,
1394                },
1395                _ => ApiError::Unknown {
1396                    provider: self.name().to_string(),
1397                    details: error_text,
1398                },
1399            });
1400        }
1401
1402        let byte_stream = response.bytes_stream();
1403        let sse_stream = parse_sse_stream(byte_stream);
1404
1405        Ok(Box::pin(Self::convert_gemini_stream(sse_stream, token)))
1406    }
1407}
1408
1409impl GeminiClient {
1410    fn convert_gemini_stream(
1411        mut sse_stream: impl futures::Stream<Item = Result<crate::api::sse::SseEvent, SseParseError>>
1412        + Unpin
1413        + Send
1414        + 'static,
1415        token: CancellationToken,
1416    ) -> impl futures::Stream<Item = StreamChunk> + Send + 'static {
1417        async_stream::stream! {
1418            let mut content: Vec<AssistantContent> = Vec::new();
1419            loop {
1420                if token.is_cancelled() {
1421                    yield StreamChunk::Error(StreamError::Cancelled);
1422                    break;
1423                }
1424
1425                let event_result = tokio::select! {
1426                    biased;
1427                    () = token.cancelled() => {
1428                        yield StreamChunk::Error(StreamError::Cancelled);
1429                        break;
1430                    }
1431                    event = sse_stream.next() => event
1432                };
1433
1434                let Some(event_result) = event_result else {
1435                    let content = std::mem::take(&mut content);
1436                    yield StreamChunk::MessageComplete(CompletionResponse { content });
1437                    break;
1438                };
1439
1440                let event = match event_result {
1441                    Ok(e) => e,
1442                    Err(e) => {
1443                        yield StreamChunk::Error(StreamError::SseParse(e));
1444                        break;
1445                    }
1446                };
1447
1448                let chunk: GeminiResponse = match serde_json::from_str(&event.data) {
1449                    Ok(c) => c,
1450                    Err(e) => {
1451                        debug!(target: "gemini::stream", "Failed to parse chunk: {} data: {}", e, event.data);
1452                        continue;
1453                    }
1454                };
1455
1456                if let Some(candidates) = chunk.candidates {
1457                    for candidate in candidates {
1458                        for part in candidate.content.parts {
1459                            let GeminiResponsePart {
1460                                thought,
1461                                thought_signature,
1462                                data,
1463                            } = part;
1464                            let thought_signature =
1465                                thought_signature.map(ThoughtSignature::new);
1466
1467                            if thought {
1468                                if let GeminiResponsePartData::Text { text } = data {
1469                                    match content.last_mut() {
1470                                        Some(AssistantContent::Thought {
1471                                            thought: ThoughtContent::Simple { text: buf },
1472                                        }) => buf.push_str(&text),
1473                                        _ => {
1474                                            content.push(AssistantContent::Thought {
1475                                                thought: ThoughtContent::Simple { text: text.clone() },
1476                                            });
1477                                        }
1478                                    }
1479                                    yield StreamChunk::ThinkingDelta(text);
1480                                }
1481                            } else {
1482                                match data {
1483                                    GeminiResponsePartData::Text { text } => {
1484                                        match content.last_mut() {
1485                                            Some(AssistantContent::Text { text: buf }) => buf.push_str(&text),
1486                                            _ => {
1487                                                content.push(AssistantContent::Text { text: text.clone() });
1488                                            }
1489                                        }
1490                                        yield StreamChunk::TextDelta(text);
1491                                    }
1492                                    GeminiResponsePartData::FunctionCall { function_call } => {
1493                                        let id = uuid::Uuid::new_v4().to_string();
1494                                        content.push(AssistantContent::ToolCall {
1495                                            tool_call: steer_tools::ToolCall {
1496                                                id: id.clone(),
1497                                                name: function_call.name.clone(),
1498                                                parameters: function_call.args.clone(),
1499                                            },
1500                                            thought_signature,
1501                                        });
1502                                        yield StreamChunk::ToolUseStart {
1503                                            id: id.clone(),
1504                                            name: function_call.name,
1505                                        };
1506                                        yield StreamChunk::ToolUseInputDelta {
1507                                            id,
1508                                            delta: function_call.args.to_string(),
1509                                        };
1510                                    }
1511                                    _ => {}
1512                                }
1513                            }
1514                        }
1515                    }
1516                }
1517            }
1518        }
1519    }
1520}
1521
1522#[cfg(test)]
1523mod tests {
1524    use super::*;
1525    use crate::app::conversation::{ImageContent, ImageSource, Message, MessageData, UserContent};
1526    use serde_json::json;
1527
1528    #[test]
1529    fn test_simplify_property_schema_removes_additional_properties() {
1530        let property_value = json!({
1531            "type": "object",
1532            "properties": {
1533                "name": {"type": "string"}
1534            },
1535            "additionalProperties": false
1536        });
1537
1538        let expected = json!({
1539            "type": "object",
1540            "properties": {
1541                "name": {"type": "string"}
1542            }
1543        });
1544
1545        let result = simplify_property_schema("testProp", "testTool", &property_value);
1546        assert_eq!(result, expected);
1547    }
1548
1549    #[test]
1550    fn test_simplify_property_schema_removes_unsupported_string_formats() {
1551        let property_value = json!({
1552            "type": "string",
1553            "format": "uri",
1554            "minLength": 1,
1555            "maxLength": 100,
1556            "pattern": "^https://"
1557        });
1558
1559        let expected = json!({
1560            "type": "string"
1561        });
1562
1563        let result = simplify_property_schema("urlProp", "testTool", &property_value);
1564        assert_eq!(result, expected);
1565    }
1566
1567    #[test]
1568    fn test_simplify_property_schema_keeps_supported_string_formats() {
1569        let property_value = json!({
1570            "type": "string",
1571            "format": "date-time"
1572        });
1573
1574        let expected = json!({
1575            "type": "string",
1576            "format": "date-time"
1577        });
1578
1579        let result = simplify_property_schema("dateProp", "testTool", &property_value);
1580        assert_eq!(result, expected);
1581    }
1582
1583    #[test]
1584    fn test_simplify_property_schema_handles_array_types() {
1585        let property_value = json!({
1586            "type": ["string", "null"],
1587            "format": "email"
1588        });
1589
1590        let expected = json!({
1591            "type": "string"
1592        });
1593
1594        let result = simplify_property_schema("emailProp", "testTool", &property_value);
1595        assert_eq!(result, expected);
1596    }
1597
1598    #[test]
1599    fn test_simplify_property_schema_recursively_handles_array_items() {
1600        let property_value = json!({
1601            "type": "array",
1602            "items": {
1603                "type": "object",
1604                "properties": {
1605                    "url": {
1606                        "type": "string",
1607                        "format": "uri"
1608                    }
1609                },
1610                "additionalProperties": false
1611            }
1612        });
1613
1614        let expected = json!({
1615            "type": "array",
1616            "items": {
1617                "type": "object",
1618                "properties": {
1619                    "url": {
1620                        "type": "string"
1621                    }
1622                }
1623            }
1624        });
1625
1626        let result = simplify_property_schema("linksProp", "testTool", &property_value);
1627        assert_eq!(result, expected);
1628    }
1629
1630    #[test]
1631    fn test_simplify_property_schema_recursively_handles_nested_objects() {
1632        let property_value = json!({
1633            "type": "object",
1634            "properties": {
1635                "nested": {
1636                    "type": "object",
1637                    "properties": {
1638                        "field": {
1639                            "type": "string",
1640                            "format": "hostname"
1641                        }
1642                    },
1643                    "additionalProperties": true
1644                }
1645            },
1646            "additionalProperties": false
1647        });
1648
1649        let expected = json!({
1650            "type": "object",
1651            "properties": {
1652                "nested": {
1653                    "type": "object",
1654                    "properties": {
1655                        "field": {
1656                            "type": "string"
1657                        }
1658                    }
1659                }
1660            }
1661        });
1662
1663        let result = simplify_property_schema("complexProp", "testTool", &property_value);
1664        assert_eq!(result, expected);
1665    }
1666
1667    #[test]
1668    fn test_simplify_property_schema_fixes_uint64_format() {
1669        let property_value = json!({
1670            "type": "integer",
1671            "format": "uint64"
1672        });
1673
1674        let expected = json!({
1675            "type": "integer",
1676            "format": "int64"
1677        });
1678
1679        let result = simplify_property_schema("idProp", "testTool", &property_value);
1680        assert_eq!(result, expected);
1681    }
1682
1683    #[test]
1684    fn test_convert_tools_integration() {
1685        use steer_tools::{InputSchema, ToolSchema};
1686
1687        let tool = ToolSchema {
1688            name: "create_issue".to_string(),
1689            display_name: "Create Issue".to_string(),
1690            description: "Create an issue".to_string(),
1691            input_schema: InputSchema::object(
1692                {
1693                    let mut props = serde_json::Map::new();
1694                    props.insert(
1695                        "title".to_string(),
1696                        json!({
1697                            "type": "string",
1698                            "minLength": 1
1699                        }),
1700                    );
1701                    props.insert(
1702                        "links".to_string(),
1703                        json!({
1704                            "type": "array",
1705                            "items": {
1706                                "type": "object",
1707                                "properties": {
1708                                    "url": {
1709                                        "type": "string",
1710                                        "format": "uri"
1711                                    }
1712                                },
1713                                "additionalProperties": false
1714                            }
1715                        }),
1716                    );
1717                    props
1718                },
1719                vec!["title".to_string()],
1720            ),
1721        };
1722
1723        let expected_tools = vec![GeminiTool {
1724            function_declarations: vec![GeminiFunctionDeclaration {
1725                name: "create_issue".to_string(),
1726                description: "Create an issue".to_string(),
1727                parameters: GeminiParameterSchema {
1728                    schema_type: "object".to_string(),
1729                    properties: {
1730                        let mut props = serde_json::Map::new();
1731                        props.insert(
1732                            "title".to_string(),
1733                            json!({
1734                                "type": "string"
1735                            }),
1736                        );
1737                        props.insert(
1738                            "links".to_string(),
1739                            json!({
1740                                "type": "array",
1741                                "items": {
1742                                    "type": "object",
1743                                    "properties": {
1744                                        "url": {
1745                                            "type": "string"
1746                                        }
1747                                    }
1748                                }
1749                            }),
1750                        );
1751                        props
1752                    },
1753                    required: vec!["title".to_string()],
1754                },
1755            }],
1756        }];
1757
1758        let result = convert_tools(vec![tool]);
1759        assert_eq!(result, expected_tools);
1760    }
1761
1762    #[test]
1763    fn test_convert_messages_includes_data_url_image_as_inline_data() {
1764        let messages = vec![Message {
1765            data: MessageData::User {
1766                content: vec![UserContent::Image {
1767                    image: ImageContent {
1768                        mime_type: "image/png".to_string(),
1769                        source: ImageSource::DataUrl {
1770                            data_url: "".to_string(),
1771                        },
1772                        width: None,
1773                        height: None,
1774                        bytes: None,
1775                        sha256: None,
1776                    },
1777                }],
1778            },
1779            timestamp: 1,
1780            id: "msg-image".to_string(),
1781            parent_message_id: None,
1782        }];
1783
1784        let converted = convert_messages(messages).expect("convert messages");
1785        assert_eq!(converted.len(), 1);
1786        let first = &converted[0];
1787        assert_eq!(first.role, "user");
1788        assert_eq!(first.parts.len(), 1);
1789
1790        let json = serde_json::to_value(&first.parts[0]).expect("serialize part");
1791        assert_eq!(json["inlineData"]["mimeType"], "image/png");
1792        assert_eq!(json["inlineData"]["data"], "Zm9v");
1793    }
1794
1795    #[test]
1796    fn test_convert_messages_rejects_session_file_image_source() {
1797        let messages = vec![Message {
1798            data: MessageData::User {
1799                content: vec![UserContent::Image {
1800                    image: ImageContent {
1801                        mime_type: "image/png".to_string(),
1802                        source: ImageSource::SessionFile {
1803                            relative_path: "session-1/image.png".to_string(),
1804                        },
1805                        width: None,
1806                        height: None,
1807                        bytes: None,
1808                        sha256: None,
1809                    },
1810                }],
1811            },
1812            timestamp: 1,
1813            id: "msg-image".to_string(),
1814            parent_message_id: None,
1815        }];
1816
1817        let err = convert_messages(messages).expect_err("expected unsupported feature");
1818        match err {
1819            ApiError::UnsupportedFeature {
1820                provider,
1821                feature,
1822                details,
1823            } => {
1824                assert_eq!(provider, "google");
1825                assert_eq!(feature, "image input source");
1826                assert!(details.contains("session file"));
1827            }
1828            other => panic!("Expected UnsupportedFeature, got {other:?}"),
1829        }
1830    }
1831
1832    #[tokio::test]
1833    async fn test_convert_gemini_stream_text_deltas() {
1834        use crate::api::provider::StreamChunk;
1835        use crate::api::sse::SseEvent;
1836        use futures::StreamExt;
1837        use futures::stream;
1838        use std::pin::pin;
1839        use tokio_util::sync::CancellationToken;
1840
1841        let events = vec![
1842            Ok(SseEvent {
1843                event_type: None,
1844                data: r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#
1845                    .to_string(),
1846                id: None,
1847            }),
1848            Ok(SseEvent {
1849                event_type: None,
1850                data:
1851                    r#"{"candidates":[{"content":{"role":"model","parts":[{"text":" world"}]}}]}"#
1852                        .to_string(),
1853                id: None,
1854            }),
1855        ];
1856
1857        let sse_stream = stream::iter(events);
1858        let token = CancellationToken::new();
1859        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));
1860
1861        let first_delta = stream.next().await.unwrap();
1862        assert!(matches!(first_delta, StreamChunk::TextDelta(ref t) if t == "Hello"));
1863
1864        let second_delta = stream.next().await.unwrap();
1865        assert!(matches!(second_delta, StreamChunk::TextDelta(ref t) if t == " world"));
1866
1867        let complete = stream.next().await.unwrap();
1868        assert!(matches!(complete, StreamChunk::MessageComplete(_)));
1869    }
1870
1871    #[tokio::test]
1872    async fn test_convert_gemini_stream_with_thinking() {
1873        use crate::api::provider::StreamChunk;
1874        use crate::api::sse::SseEvent;
1875        use futures::StreamExt;
1876        use futures::stream;
1877        use std::pin::pin;
1878        use tokio_util::sync::CancellationToken;
1879
1880        let events = vec![
1881            Ok(SseEvent {
1882                event_type: None,
1883                data: r#"{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"Let me think..."}]}}]}"#.to_string(),
1884                id: None,
1885            }),
1886            Ok(SseEvent {
1887                event_type: None,
1888                data: r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"The answer"}]}}]}"#.to_string(),
1889                id: None,
1890            }),
1891        ];
1892
1893        let sse_stream = stream::iter(events);
1894        let token = CancellationToken::new();
1895        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));
1896
1897        let thinking_delta = stream.next().await.unwrap();
1898        assert!(
1899            matches!(thinking_delta, StreamChunk::ThinkingDelta(ref t) if t == "Let me think...")
1900        );
1901
1902        let text_delta = stream.next().await.unwrap();
1903        assert!(matches!(text_delta, StreamChunk::TextDelta(ref t) if t == "The answer"));
1904
1905        let complete = stream.next().await.unwrap();
1906        if let StreamChunk::MessageComplete(response) = complete {
1907            assert_eq!(response.content.len(), 2);
1908            assert!(matches!(
1909                &response.content[0],
1910                AssistantContent::Thought { .. }
1911            ));
1912            assert!(matches!(
1913                &response.content[1],
1914                AssistantContent::Text { .. }
1915            ));
1916        } else {
1917            panic!("Expected MessageComplete");
1918        }
1919    }
1920
1921    #[tokio::test]
1922    async fn test_convert_gemini_stream_with_function_call() {
1923        use crate::api::provider::StreamChunk;
1924        use crate::api::sse::SseEvent;
1925        use futures::StreamExt;
1926        use futures::stream;
1927        use std::pin::pin;
1928        use tokio_util::sync::CancellationToken;
1929
1930        let events = vec![
1931            Ok(SseEvent {
1932                event_type: None,
1933                data: r#"{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"get_weather","args":{"city":"NYC"}},"thoughtSignature":"sig_123"}]}}]}"#.to_string(),
1934                id: None,
1935            }),
1936        ];
1937
1938        let sse_stream = stream::iter(events);
1939        let token = CancellationToken::new();
1940        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));
1941
1942        let tool_start = stream.next().await.unwrap();
1943        assert!(
1944            matches!(tool_start, StreamChunk::ToolUseStart { ref name, .. } if name == "get_weather")
1945        );
1946
1947        let tool_input = stream.next().await.unwrap();
1948        assert!(matches!(tool_input, StreamChunk::ToolUseInputDelta { .. }));
1949
1950        let complete = stream.next().await.unwrap();
1951        if let StreamChunk::MessageComplete(response) = complete {
1952            assert_eq!(response.content.len(), 1);
1953            if let AssistantContent::ToolCall {
1954                tool_call,
1955                thought_signature,
1956            } = &response.content[0]
1957            {
1958                assert_eq!(tool_call.name, "get_weather");
1959                assert_eq!(
1960                    thought_signature.as_ref().map(|sig| sig.as_str()),
1961                    Some("sig_123")
1962                );
1963            } else {
1964                panic!("Expected ToolCall");
1965            }
1966        } else {
1967            panic!("Expected MessageComplete");
1968        }
1969    }
1970
1971    #[tokio::test]
1972    async fn test_convert_gemini_stream_cancellation() {
1973        use crate::api::error::StreamError;
1974        use crate::api::provider::StreamChunk;
1975        use crate::api::sse::SseEvent;
1976        use futures::StreamExt;
1977        use futures::stream;
1978        use std::pin::pin;
1979        use tokio_util::sync::CancellationToken;
1980
1981        let events = vec![Ok(SseEvent {
1982            event_type: None,
1983            data: r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#
1984                .to_string(),
1985            id: None,
1986        })];
1987
1988        let sse_stream = stream::iter(events);
1989        let token = CancellationToken::new();
1990        token.cancel();
1991
1992        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));
1993
1994        let cancelled = stream.next().await.unwrap();
1995        assert!(matches!(
1996            cancelled,
1997            StreamChunk::Error(StreamError::Cancelled)
1998        ));
1999    }
2000
2001    #[tokio::test]
2002    #[ignore = "Requires GOOGLE_API_KEY environment variable"]
2003    async fn test_stream_complete_real_api() {
2004        use crate::api::Provider;
2005        use crate::api::provider::StreamChunk;
2006        use crate::app::conversation::{Message, MessageData, UserContent};
2007        use futures::StreamExt;
2008        use tokio_util::sync::CancellationToken;
2009
2010        dotenvy::dotenv().ok();
2011        let api_key = std::env::var("GOOGLE_API_KEY").expect("GOOGLE_API_KEY must be set");
2012        let client = GeminiClient::new(api_key);
2013
2014        let message = Message {
2015            data: MessageData::User {
2016                content: vec![UserContent::Text {
2017                    text: "Say exactly: Hello".to_string(),
2018                }],
2019            },
2020            timestamp: chrono::Utc::now().timestamp_millis() as u64,
2021            id: "test-msg".to_string(),
2022            parent_message_id: None,
2023        };
2024
2025        let model_id = ModelId::new(
2026            crate::config::provider::google(),
2027            "gemini-2.5-flash-preview-04-17",
2028        );
2029        let token = CancellationToken::new();
2030
2031        let mut stream = client
2032            .stream_complete(&model_id, vec![message], None, None, None, token)
2033            .await
2034            .expect("stream_complete should succeed");
2035
2036        let mut got_text_delta = false;
2037        let mut got_message_complete = false;
2038        let mut accumulated_text = String::new();
2039
2040        while let Some(chunk) = stream.next().await {
2041            match chunk {
2042                StreamChunk::TextDelta(text) => {
2043                    got_text_delta = true;
2044                    accumulated_text.push_str(&text);
2045                }
2046                StreamChunk::MessageComplete(response) => {
2047                    got_message_complete = true;
2048                    assert!(!response.content.is_empty());
2049                }
2050                StreamChunk::Error(e) => panic!("Unexpected error: {e:?}"),
2051                _ => {}
2052            }
2053        }
2054
2055        assert!(got_text_delta, "Should receive at least one TextDelta");
2056        assert!(
2057            got_message_complete,
2058            "Should receive MessageComplete at the end"
2059        );
2060        assert!(
2061            accumulated_text.to_lowercase().contains("hello"),
2062            "Response should contain 'hello', got: {accumulated_text}"
2063        );
2064    }
2065}