Skip to main content

steer_core/api/gemini/
client.rs

1use async_trait::async_trait;
2use futures::StreamExt;
3use reqwest::{Client as HttpClient, StatusCode};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use tokio_util::sync::CancellationToken;
7use tracing::{debug, error, info, warn};
8
9use crate::api::error::{ApiError, SseParseError, StreamError};
10use crate::api::provider::{CompletionResponse, CompletionStream, Provider, StreamChunk};
11use crate::api::sse::parse_sse_stream;
12use crate::app::SystemContext;
13use crate::app::conversation::{
14    AssistantContent, Message as AppMessage, ThoughtContent, ThoughtSignature, ToolResult,
15    UserContent,
16};
17use crate::config::model::{ModelId, ModelParameters};
18use steer_tools::ToolSchema;
19
/// Base URL, including the `v1beta` version segment, of Google's Generative Language
/// (Gemini) API.
const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
21
22#[derive(Debug, Deserialize, Serialize, Clone)] // Added Serialize and Clone for potential future use
23struct GeminiBlob {
24    #[serde(rename = "mimeType")]
25    mime_type: String,
26    data: String, // Assuming base64 encoded data
27}
28
29#[derive(Debug, Deserialize, Serialize, Clone)] // Added Serialize and Clone
30struct GeminiFileData {
31    #[serde(rename = "mimeType")]
32    mime_type: String,
33    #[serde(rename = "fileUri")]
34    file_uri: String,
35}
36
37pub struct GeminiClient {
38    api_key: String,
39    client: HttpClient,
40}
41
42impl GeminiClient {
43    pub fn new(api_key: impl Into<String>) -> Self {
44        Self {
45            api_key: api_key.into(),
46            client: HttpClient::new(),
47        }
48    }
49}
50
51#[derive(Debug, Serialize)]
52struct GeminiRequest {
53    contents: Vec<GeminiContent>,
54    #[serde(skip_serializing_if = "Option::is_none")]
55    #[serde(rename = "systemInstruction")]
56    system_instruction: Option<GeminiSystemInstruction>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    tools: Option<Vec<GeminiTool>>,
59    #[serde(skip_serializing_if = "Option::is_none")]
60    #[serde(rename = "generationConfig")]
61    generation_config: Option<GeminiGenerationConfig>,
62}
63
64#[derive(Debug, Serialize, Default, Clone)]
65struct GeminiGenerationConfig {
66    #[serde(skip_serializing_if = "Option::is_none")]
67    #[serde(rename = "stopSequences")]
68    stop_sequences: Option<Vec<String>>,
69    #[serde(skip_serializing_if = "Option::is_none")]
70    #[serde(rename = "responseMimeType")]
71    response_mime_type: Option<GeminiMimeType>,
72    #[serde(skip_serializing_if = "Option::is_none")]
73    #[serde(rename = "candidateCount")]
74    candidate_count: Option<i32>,
75    #[serde(skip_serializing_if = "Option::is_none")]
76    #[serde(rename = "maxOutputTokens")]
77    max_output_tokens: Option<i32>,
78    #[serde(skip_serializing_if = "Option::is_none")]
79    temperature: Option<f32>,
80    #[serde(skip_serializing_if = "Option::is_none")]
81    #[serde(rename = "topP")]
82    top_p: Option<f32>,
83    #[serde(skip_serializing_if = "Option::is_none")]
84    #[serde(rename = "topK")]
85    top_k: Option<i32>,
86    #[serde(skip_serializing_if = "Option::is_none")]
87    #[serde(rename = "thinkingConfig")]
88    thinking_config: Option<GeminiThinkingConfig>,
89}
90
91#[derive(Debug, Serialize, Default, Clone)]
92struct GeminiThinkingConfig {
93    #[serde(skip_serializing_if = "Option::is_none")]
94    #[serde(rename = "includeThoughts")]
95    include_thoughts: Option<bool>,
96    #[serde(skip_serializing_if = "Option::is_none")]
97    #[serde(rename = "thinkingBudget")]
98    thinking_budget: Option<i32>,
99}
100
101#[expect(dead_code)]
102#[derive(Debug, Serialize, Clone)]
103#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
104enum GeminiMimeType {
105    MimeTypeUnspecified,
106    TextPlain,
107    ApplicationJson,
108}
109
110#[derive(Debug, Serialize)]
111struct GeminiSystemInstruction {
112    parts: Vec<GeminiRequestPart>,
113}
114
115#[derive(Debug, Serialize)]
116struct GeminiContent {
117    role: String,
118    parts: Vec<GeminiRequestPart>,
119}
120
121// Enum for parts used ONLY in requests
122#[derive(Debug, Serialize)]
123#[serde(untagged)]
124enum GeminiRequestPart {
125    Text {
126        text: String,
127    },
128    #[serde(rename = "functionCall")]
129    FunctionCall {
130        #[serde(rename = "functionCall")]
131        function_call: GeminiFunctionCall, // Used for model turns in history
132        #[serde(rename = "thoughtSignature", skip_serializing_if = "Option::is_none")]
133        thought_signature: Option<String>,
134    },
135    #[serde(rename = "functionResponse")]
136    FunctionResponse {
137        #[serde(rename = "functionResponse")]
138        function_response: GeminiFunctionResponse, // Used for function/tool turns
139    },
140}
141
142// Enum for parts received ONLY in responses
143#[derive(Debug, Deserialize)]
144#[serde(untagged)]
145enum GeminiResponsePartData {
146    Text {
147        text: String,
148    },
149    #[serde(rename = "inlineData")]
150    InlineData {
151        #[serde(rename = "inlineData")]
152        inline_data: GeminiBlob,
153    },
154    #[serde(rename = "functionCall")]
155    FunctionCall {
156        #[serde(rename = "functionCall")]
157        function_call: GeminiFunctionCall,
158    },
159    #[serde(rename = "fileData")]
160    FileData {
161        #[serde(rename = "fileData")]
162        file_data: GeminiFileData,
163    },
164    #[serde(rename = "executableCode")]
165    ExecutableCode {
166        #[serde(rename = "executableCode")]
167        executable_code: GeminiExecutableCode,
168    },
169    // Add other variants back here if needed
170}
171
172// 2. Change GeminiResponsePart to a struct
173#[derive(Debug, Deserialize)]
174struct GeminiResponsePart {
175    #[serde(default)] // Defaults to false if missing
176    thought: bool,
177    #[serde(default, rename = "thoughtSignature", alias = "thought_signature")]
178    thought_signature: Option<String>,
179
180    #[serde(flatten)] // Look for data fields directly in this struct's JSON
181    data: GeminiResponsePartData,
182}
183
184#[derive(Debug, Serialize, Deserialize)]
185struct GeminiFunctionCall {
186    name: String,
187    args: Value,
188}
189
190#[derive(Debug, Serialize, PartialEq)]
191struct GeminiTool {
192    #[serde(rename = "functionDeclarations")]
193    function_declarations: Vec<GeminiFunctionDeclaration>,
194}
195
196#[derive(Debug, Serialize, PartialEq)]
197struct GeminiFunctionDeclaration {
198    name: String,
199    description: String,
200    parameters: GeminiParameterSchema,
201}
202
203#[derive(Debug, Serialize, PartialEq)]
204struct GeminiParameterSchema {
205    #[serde(rename = "type")]
206    schema_type: String, // Typically "object"
207    properties: serde_json::Map<String, Value>,
208    required: Vec<String>,
209}
210
211#[derive(Debug, Deserialize)]
212struct GeminiResponse {
213    #[serde(rename = "candidates")]
214    #[serde(skip_serializing_if = "Option::is_none")]
215    candidates: Option<Vec<GeminiCandidate>>,
216    #[serde(rename = "promptFeedback")]
217    #[serde(skip_serializing_if = "Option::is_none")]
218    prompt_feedback: Option<GeminiPromptFeedback>,
219    #[serde(rename = "usageMetadata")]
220    #[serde(skip_serializing_if = "Option::is_none")]
221    usage_metadata: Option<GeminiUsageMetadata>,
222}
223
224#[derive(Debug, Deserialize)]
225struct GeminiCandidate {
226    content: GeminiContentResponse,
227    #[serde(rename = "finishReason")]
228    #[serde(skip_serializing_if = "Option::is_none")]
229    finish_reason: Option<GeminiFinishReason>,
230    #[serde(rename = "safetyRatings")]
231    #[serde(skip_serializing_if = "Option::is_none")]
232    safety_ratings: Option<Vec<GeminiSafetyRating>>,
233    #[serde(rename = "citationMetadata")]
234    #[serde(skip_serializing_if = "Option::is_none")]
235    citation_metadata: Option<GeminiCitationMetadata>,
236}
237
238#[derive(Debug, Deserialize)]
239#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
240enum GeminiFinishReason {
241    FinishReasonUnspecified,
242    Stop,
243    MaxTokens,
244    Safety,
245    Recitation,
246    Other,
247    #[serde(rename = "TOOL_CODE_ERROR")]
248    ToolCodeError,
249    #[serde(rename = "TOOL_EXECUTION_HALT")]
250    ToolExecutionHalt,
251    MalformedFunctionCall,
252}
253
254#[derive(Debug, Deserialize)]
255struct GeminiPromptFeedback {
256    #[serde(rename = "blockReason")]
257    #[serde(skip_serializing_if = "Option::is_none")]
258    block_reason: Option<GeminiBlockReason>,
259    #[serde(rename = "safetyRatings")]
260    #[serde(skip_serializing_if = "Option::is_none")]
261    safety_ratings: Option<Vec<GeminiSafetyRating>>,
262}
263
264#[derive(Debug, Deserialize)]
265#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
266enum GeminiBlockReason {
267    BlockReasonUnspecified,
268    Safety,
269    Other,
270}
271
272#[derive(Debug, Deserialize)]
273#[expect(dead_code)]
274struct GeminiSafetyRating {
275    category: GeminiHarmCategory,
276    probability: GeminiHarmProbability,
277    #[serde(default)] // Default to false if missing
278    blocked: bool,
279}
280
281#[derive(Debug, Deserialize, Serialize)] // Add Serialize for potential use in SafetySetting
282#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
283#[expect(clippy::enum_variant_names)]
284enum GeminiHarmCategory {
285    HarmCategoryUnspecified,
286    HarmCategoryDerogatory,
287    HarmCategoryToxicity,
288    HarmCategoryViolence,
289    HarmCategorySexual,
290    HarmCategoryMedical,
291    HarmCategoryDangerous,
292    HarmCategoryHarassment,
293    HarmCategoryHateSpeech,
294    HarmCategorySexuallyExplicit,
295    HarmCategoryDangerousContent,
296    HarmCategoryCivicIntegrity,
297}
298
299#[derive(Debug, Deserialize)]
300#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
301enum GeminiHarmProbability {
302    HarmProbabilityUnspecified,
303    Negligible,
304    Low,
305    Medium,
306    High,
307}
308
309#[expect(dead_code)]
310#[derive(Debug, Deserialize)]
311struct GeminiCitationMetadata {
312    #[serde(rename = "citationSources")]
313    #[serde(skip_serializing_if = "Option::is_none")]
314    citation_sources: Option<Vec<GeminiCitationSource>>,
315}
316
317#[expect(dead_code)]
318#[derive(Debug, Deserialize)]
319struct GeminiCitationSource {
320    #[serde(rename = "startIndex")]
321    #[serde(skip_serializing_if = "Option::is_none")]
322    start_index: Option<i32>,
323    #[serde(rename = "endIndex")]
324    #[serde(skip_serializing_if = "Option::is_none")]
325    end_index: Option<i32>,
326    #[serde(skip_serializing_if = "Option::is_none")]
327    uri: Option<String>,
328    #[serde(skip_serializing_if = "Option::is_none")]
329    license: Option<String>,
330}
331
332#[derive(Debug, Deserialize)]
333struct GeminiUsageMetadata {
334    #[serde(rename = "promptTokenCount")]
335    #[serde(skip_serializing_if = "Option::is_none")]
336    prompt: Option<i32>,
337    #[serde(rename = "candidatesTokenCount")]
338    #[serde(skip_serializing_if = "Option::is_none")]
339    candidates: Option<i32>,
340    #[serde(rename = "totalTokenCount")]
341    #[serde(skip_serializing_if = "Option::is_none")]
342    total: Option<i32>,
343}
344
345#[derive(Debug, Serialize, Deserialize)]
346struct GeminiFunctionResponse {
347    name: String,
348    response: GeminiResponseContent,
349}
350
351#[derive(Debug, Serialize, Deserialize)]
352struct GeminiResponseContent {
353    content: Value,
354}
355
356#[derive(Debug, Serialize, Deserialize)]
357struct GeminiExecutableCode {
358    language: String, // e.g., PYTHON
359    code: String,
360}
361
362#[derive(Debug, Deserialize)]
363#[expect(dead_code)]
364struct GeminiContentResponse {
365    role: String,
366    parts: Vec<GeminiResponsePart>,
367}
368
369fn convert_messages(messages: Vec<AppMessage>) -> Vec<GeminiContent> {
370    messages
371        .into_iter()
372        .filter_map(|msg| match &msg.data {
373            crate::app::conversation::MessageData::User { content, .. } => {
374                let parts: Vec<GeminiRequestPart> = content
375                    .iter()
376                    .map(|user_content| match user_content {
377                        UserContent::Text { text } => {
378                            GeminiRequestPart::Text { text: text.clone() }
379                        }
380                        UserContent::CommandExecution {
381                            command,
382                            stdout,
383                            stderr,
384                            exit_code,
385                        } => GeminiRequestPart::Text {
386                            text: UserContent::format_command_execution_as_xml(
387                                command, stdout, stderr, *exit_code,
388                            ),
389                        },
390                    })
391                    .collect();
392
393                // Only include the message if it has content after filtering
394                if parts.is_empty() {
395                    None
396                } else {
397                    Some(GeminiContent {
398                        role: "user".to_string(),
399                        parts,
400                    })
401                }
402            }
403            crate::app::conversation::MessageData::Assistant { content, .. } => {
404                let parts: Vec<GeminiRequestPart> = content
405                    .iter()
406                    .filter_map(|assistant_content| match assistant_content {
407                        AssistantContent::Text { text } => {
408                            Some(GeminiRequestPart::Text { text: text.clone() })
409                        }
410                        AssistantContent::ToolCall {
411                            tool_call,
412                            thought_signature,
413                        } => Some(GeminiRequestPart::FunctionCall {
414                            function_call: GeminiFunctionCall {
415                                name: tool_call.name.clone(),
416                                args: tool_call.parameters.clone(),
417                            },
418                            thought_signature: thought_signature
419                                .as_ref()
420                                .map(|signature| signature.as_str().to_string()),
421                        }),
422                        AssistantContent::Thought { .. } => {
423                            // Gemini doesn't send thought blocks in requests
424                            None
425                        }
426                    })
427                    .collect();
428
429                // Always include assistant messages (they should always have content)
430                Some(GeminiContent {
431                    role: "model".to_string(),
432                    parts,
433                })
434            }
435            crate::app::conversation::MessageData::Tool {
436                tool_use_id,
437                result,
438                ..
439            } => {
440                // Convert tool result to function response
441                let result_value = match result {
442                    ToolResult::Error(e) => Value::String(format!("Error: {e}")),
443                    _ => {
444                        // For all other variants, try to serialize as JSON
445                        serde_json::to_value(result)
446                            .unwrap_or_else(|_| Value::String(result.llm_format()))
447                    }
448                };
449
450                let parts = vec![GeminiRequestPart::FunctionResponse {
451                    function_response: GeminiFunctionResponse {
452                        name: tool_use_id.clone(), // Use tool_use_id as function name
453                        response: GeminiResponseContent {
454                            content: result_value,
455                        },
456                    },
457                }];
458
459                Some(GeminiContent {
460                    role: "function".to_string(),
461                    parts,
462                })
463            }
464        })
465        .collect()
466}
467
468fn resolve_ref<'a>(root: &'a Value, schema: &'a Value) -> Option<&'a Value> {
469    let reference = schema.get("$ref").and_then(|v| v.as_str())?;
470    let path = reference.strip_prefix("#/")?;
471    let mut current = root;
472    for segment in path.split('/') {
473        current = current.get(segment)?;
474    }
475    Some(current)
476}
477
478fn infer_type_from_enum(values: &[Value]) -> Option<String> {
479    let mut has_string = false;
480    let mut has_number = false;
481    let mut has_bool = false;
482    let mut has_object = false;
483    let mut has_array = false;
484
485    for value in values {
486        match value {
487            Value::String(_) => has_string = true,
488            Value::Number(_) => has_number = true,
489            Value::Bool(_) => has_bool = true,
490            Value::Object(_) => has_object = true,
491            Value::Array(_) => has_array = true,
492            Value::Null => {}
493        }
494    }
495
496    let kind_count = u8::from(has_string)
497        + u8::from(has_number)
498        + u8::from(has_bool)
499        + u8::from(has_object)
500        + u8::from(has_array);
501
502    if kind_count != 1 {
503        return None;
504    }
505
506    if has_string {
507        Some("string".to_string())
508    } else if has_number {
509        Some("number".to_string())
510    } else if has_bool {
511        Some("boolean".to_string())
512    } else if has_object {
513        Some("object".to_string())
514    } else if has_array {
515        Some("array".to_string())
516    } else {
517        None
518    }
519}
520
521fn normalize_type(value: &Value) -> Value {
522    if let Some(type_str) = value.as_str() {
523        return Value::String(type_str.to_string());
524    }
525
526    if let Some(type_array) = value.as_array()
527        && let Some(primary_type) = type_array
528            .iter()
529            .find_map(|v| if v.is_null() { None } else { v.as_str() })
530    {
531        return Value::String(primary_type.to_string());
532    }
533
534    Value::String("string".to_string())
535}
536
537fn extract_enum_values(value: &Value) -> Vec<Value> {
538    let Some(obj) = value.as_object() else {
539        return Vec::new();
540    };
541
542    if let Some(enum_values) = obj.get("enum").and_then(|v| v.as_array()) {
543        return enum_values
544            .iter()
545            .filter(|v| !v.is_null())
546            .cloned()
547            .collect();
548    }
549
550    if let Some(const_value) = obj.get("const") {
551        if const_value.is_null() {
552            return Vec::new();
553        }
554        return vec![const_value.clone()];
555    }
556
557    Vec::new()
558}
559
560fn merge_property(properties: &mut serde_json::Map<String, Value>, key: &str, value: &Value) {
561    match properties.get_mut(key) {
562        None => {
563            properties.insert(key.to_string(), value.clone());
564        }
565        Some(existing) => {
566            if existing == value {
567                return;
568            }
569
570            let existing_values = extract_enum_values(existing);
571            let incoming_values = extract_enum_values(value);
572            if incoming_values.is_empty() && existing_values.is_empty() {
573                return;
574            }
575
576            let mut combined = existing_values;
577            for item in incoming_values {
578                if !combined.contains(&item) {
579                    combined.push(item);
580                }
581            }
582
583            if combined.is_empty() {
584                return;
585            }
586
587            if let Some(obj) = existing.as_object_mut() {
588                obj.remove("const");
589                obj.insert("enum".to_string(), Value::Array(combined.clone()));
590                if !obj.contains_key("type")
591                    && let Some(inferred) = infer_type_from_enum(&combined)
592                {
593                    obj.insert("type".to_string(), Value::String(inferred));
594                }
595            }
596        }
597    }
598}
599
600fn merge_union_schemas(root: &Value, variants: &[Value]) -> Value {
601    let mut merged_props = serde_json::Map::new();
602    let mut required_intersection: Option<std::collections::BTreeSet<String>> = None;
603    let mut enum_values: Vec<Value> = Vec::new();
604    let mut type_candidates: Vec<String> = Vec::new();
605
606    for variant in variants {
607        let sanitized = sanitize_for_gemini(root, variant);
608
609        if let Some(schema_type) = sanitized.get("type").and_then(|v| v.as_str()) {
610            type_candidates.push(schema_type.to_string());
611        }
612
613        if let Some(props) = sanitized.get("properties").and_then(|v| v.as_object()) {
614            for (key, value) in props {
615                merge_property(&mut merged_props, key, value);
616            }
617        }
618
619        if let Some(req) = sanitized.get("required").and_then(|v| v.as_array()) {
620            let req_set: std::collections::BTreeSet<String> = req
621                .iter()
622                .filter_map(|item| item.as_str().map(|s| s.to_string()))
623                .collect();
624
625            required_intersection = match required_intersection.take() {
626                None => Some(req_set),
627                Some(existing) => Some(
628                    existing
629                        .intersection(&req_set)
630                        .cloned()
631                        .collect::<std::collections::BTreeSet<String>>(),
632                ),
633            };
634        }
635
636        if let Some(values) = sanitized.get("enum").and_then(|v| v.as_array()) {
637            for value in values {
638                if value.is_null() {
639                    continue;
640                }
641                if !enum_values.contains(value) {
642                    enum_values.push(value.clone());
643                }
644            }
645        }
646    }
647
648    let schema_type = if !merged_props.is_empty() {
649        "object".to_string()
650    } else if let Some(inferred) = infer_type_from_enum(&enum_values) {
651        inferred
652    } else if let Some(first) = type_candidates.first() {
653        first.clone()
654    } else {
655        "string".to_string()
656    };
657
658    let mut merged = serde_json::Map::new();
659    merged.insert("type".to_string(), Value::String(schema_type));
660
661    if !merged_props.is_empty() {
662        merged.insert("properties".to_string(), Value::Object(merged_props));
663    }
664
665    if let Some(required_set) = required_intersection
666        && !required_set.is_empty()
667    {
668        merged.insert(
669            "required".to_string(),
670            Value::Array(
671                required_set
672                    .into_iter()
673                    .map(Value::String)
674                    .collect::<Vec<_>>(),
675            ),
676        );
677    }
678
679    if !enum_values.is_empty() {
680        merged.insert("enum".to_string(), Value::Array(enum_values));
681    }
682
683    Value::Object(merged)
684}
685
686fn sanitize_for_gemini(root: &Value, schema: &Value) -> Value {
687    if let Some(resolved) = resolve_ref(root, schema) {
688        return sanitize_for_gemini(root, resolved);
689    }
690
691    let Some(obj) = schema.as_object() else {
692        return schema.clone();
693    };
694
695    if let Some(union) = obj
696        .get("oneOf")
697        .or_else(|| obj.get("anyOf"))
698        .or_else(|| obj.get("allOf"))
699        .and_then(|v| v.as_array())
700    {
701        return merge_union_schemas(root, union);
702    }
703
704    let mut out = serde_json::Map::new();
705    for (key, value) in obj {
706        match key.as_str() {
707            "$ref"
708            | "$defs"
709            | "oneOf"
710            | "anyOf"
711            | "allOf"
712            | "const"
713            | "additionalProperties"
714            | "default"
715            | "examples"
716            | "title"
717            | "pattern"
718            | "minLength"
719            | "maxLength"
720            | "minimum"
721            | "maximum"
722            | "minItems"
723            | "maxItems"
724            | "uniqueItems"
725            | "deprecated" => {}
726            "type" => {
727                out.insert("type".to_string(), normalize_type(value));
728            }
729            "properties" => {
730                if let Some(props) = value.as_object() {
731                    let mut sanitized_props = serde_json::Map::new();
732                    for (prop_key, prop_value) in props {
733                        sanitized_props
734                            .insert(prop_key.clone(), sanitize_for_gemini(root, prop_value));
735                    }
736                    out.insert("properties".to_string(), Value::Object(sanitized_props));
737                }
738            }
739            "items" => {
740                out.insert("items".to_string(), sanitize_for_gemini(root, value));
741            }
742            "enum" => {
743                let values = value
744                    .as_array()
745                    .map(|items| {
746                        items
747                            .iter()
748                            .filter(|v| !v.is_null())
749                            .cloned()
750                            .collect::<Vec<_>>()
751                    })
752                    .unwrap_or_default();
753                out.insert("enum".to_string(), Value::Array(values));
754            }
755            _ => {
756                out.insert(key.clone(), sanitize_for_gemini(root, value));
757            }
758        }
759    }
760
761    if let Some(const_value) = obj.get("const")
762        && !const_value.is_null()
763    {
764        out.insert("enum".to_string(), Value::Array(vec![const_value.clone()]));
765        if !out.contains_key("type")
766            && let Some(inferred) = infer_type_from_enum(std::slice::from_ref(const_value))
767        {
768            out.insert("type".to_string(), Value::String(inferred));
769        }
770    }
771
772    if out.get("type") == Some(&Value::String("object".to_string()))
773        && !out.contains_key("properties")
774    {
775        out.insert(
776            "properties".to_string(),
777            Value::Object(serde_json::Map::new()),
778        );
779    }
780
781    if !out.contains_key("type") {
782        if out.contains_key("properties") {
783            out.insert("type".to_string(), Value::String("object".to_string()));
784        } else if let Some(enum_values) = out.get("enum").and_then(|v| v.as_array())
785            && let Some(inferred) = infer_type_from_enum(enum_values)
786        {
787            out.insert("type".to_string(), Value::String(inferred));
788        }
789    }
790
791    Value::Object(out)
792}
793
794fn simplify_property_schema(key: &str, tool_name: &str, property_value: &Value) -> Value {
795    if let Some(prop_map_orig) = property_value.as_object() {
796        let mut simplified_prop = prop_map_orig.clone();
797
798        // Remove 'additionalProperties' as Gemini doesn't support it
799        if simplified_prop.remove("additionalProperties").is_some() {
800            debug!(target: "gemini::simplify_property_schema", "Removed 'additionalProperties' from property '{}' in tool '{}'", key, tool_name);
801        }
802
803        // Simplify 'type' field (handle arrays like ["string", "null"])
804        if let Some(type_val) = simplified_prop.get_mut("type") {
805            if let Some(type_array) = type_val.as_array() {
806                if let Some(primary_type) = type_array
807                    .iter()
808                    .find_map(|v| if v.is_null() { None } else { v.as_str() })
809                {
810                    *type_val = serde_json::Value::String(primary_type.to_string());
811                } else {
812                    warn!(target: "gemini::simplify_property_schema", "Could not determine primary type for property '{}' in tool '{}', defaulting to string.", key, tool_name);
813                    *type_val = serde_json::Value::String("string".to_string());
814                }
815            } else if !type_val.is_string() {
816                warn!(target: "gemini::simplify_property_schema", "Unexpected 'type' format for property '{}' in tool '{}': {:?}. Defaulting to string.", key, tool_name, type_val);
817                *type_val = serde_json::Value::String("string".to_string());
818            }
819            // If it's already a simple string, do nothing.
820        }
821
822        // Fix integer format if necessary
823        if simplified_prop.get("type") == Some(&serde_json::Value::String("integer".to_string()))
824            && let Some(format_val) = simplified_prop.get_mut("format")
825            && format_val.as_str() == Some("uint64")
826        {
827            *format_val = serde_json::Value::String("int64".to_string());
828            // Optionally remove minimum if Gemini doesn't support it with int64
829            // simplified_prop.remove("minimum");
830        }
831
832        // For string types, Gemini only supports 'enum' and 'date-time' formats
833        if simplified_prop.get("type") == Some(&serde_json::Value::String("string".to_string())) {
834            let should_remove_format = simplified_prop
835                .get("format")
836                .and_then(|f| f.as_str())
837                .is_some_and(|format_str| format_str != "enum" && format_str != "date-time");
838
839            if should_remove_format
840                && let Some(format_val) = simplified_prop.remove("format")
841                && let Some(format_str) = format_val.as_str()
842            {
843                debug!(target: "gemini::simplify_property_schema", "Removed unsupported format '{}' from string property '{}' in tool '{}'", format_str, key, tool_name);
844            }
845
846            // Also remove other string validation fields that might not be supported
847            if simplified_prop.remove("minLength").is_some() {
848                debug!(target: "gemini::simplify_property_schema", "Removed 'minLength' from string property '{}' in tool '{}'", key, tool_name);
849            }
850            if simplified_prop.remove("maxLength").is_some() {
851                debug!(target: "gemini::simplify_property_schema", "Removed 'maxLength' from string property '{}' in tool '{}'", key, tool_name);
852            }
853            if simplified_prop.remove("pattern").is_some() {
854                debug!(target: "gemini::simplify_property_schema", "Removed 'pattern' from string property '{}' in tool '{}'", key, tool_name);
855            }
856        }
857
858        // Recursively simplify 'items' if this is an array type
859        if simplified_prop.get("type") == Some(&serde_json::Value::String("array".to_string()))
860            && let Some(items_val) = simplified_prop.get_mut("items")
861        {
862            *items_val = simplify_property_schema(&format!("{key}.items"), tool_name, items_val);
863        }
864
865        // Recursively simplify nested 'properties' if this is an object type
866        if simplified_prop.get("type") == Some(&serde_json::Value::String("object".to_string()))
867            && let Some(Value::Object(props)) = simplified_prop.get_mut("properties")
868        {
869            let simplified_nested_props: serde_json::Map<String, Value> = props
870                .iter()
871                .map(|(nested_key, nested_value)| {
872                    (
873                        nested_key.clone(),
874                        simplify_property_schema(
875                            &format!("{key}.{nested_key}"),
876                            tool_name,
877                            nested_value,
878                        ),
879                    )
880                })
881                .collect();
882            *props = simplified_nested_props;
883        }
884
885        serde_json::Value::Object(simplified_prop)
886    } else {
887        warn!(target: "gemini::simplify_property_schema", "Property value for '{}' in tool '{}' is not an object: {:?}. Using original value.", key, tool_name, property_value);
888        property_value.clone() // Return original if not an object
889    }
890}
891
892fn convert_tools(tools: Vec<ToolSchema>) -> Vec<GeminiTool> {
893    let function_declarations = tools
894        .into_iter()
895        .map(|tool| {
896            let root_schema = tool.input_schema.as_value();
897            let summary = tool.input_schema.summary();
898            let schema_type = if summary.schema_type.is_empty() {
899                "object".to_string()
900            } else {
901                summary.schema_type
902            };
903
904            // Simplify properties schema for Gemini using the helper function
905            let simplified_properties = summary
906                .properties
907                .iter()
908                .map(|(key, value)| {
909                    let sanitized = sanitize_for_gemini(root_schema, value);
910                    (
911                        key.clone(),
912                        simplify_property_schema(key, &tool.name, &sanitized),
913                    )
914                })
915                .collect();
916
917            // Construct the parameters object using the specific struct
918            let parameters = GeminiParameterSchema {
919                schema_type,                       // Use schema_type field (usually "object")
920                properties: simplified_properties, // Use simplified properties
921                required: summary.required,        // Use required field
922            };
923
924            GeminiFunctionDeclaration {
925                name: tool.name,
926                description: tool.description,
927                parameters,
928            }
929        })
930        .collect();
931
932    vec![GeminiTool {
933        function_declarations,
934    }]
935}
936
937fn convert_response(response: GeminiResponse) -> Result<CompletionResponse, ApiError> {
938    // Log prompt feedback if present
939    if let Some(feedback) = &response.prompt_feedback
940        && let Some(reason) = &feedback.block_reason
941    {
942        let details = format!(
943            "Prompt blocked due to {:?}. Safety ratings: {:?}",
944            reason, feedback.safety_ratings
945        );
946        warn!(target: "gemini::convert_response", "{}", details);
947        // Return the specific RequestBlocked error
948        return Err(ApiError::RequestBlocked {
949            provider: "google".to_string(), // Assuming "google" is the provider name
950            details,
951        });
952    }
953
954    // Check candidates *after* checking for prompt blocking
955    let candidates = if let Some(cands) = response.candidates {
956        if cands.is_empty() {
957            // If it was blocked, the previous check should have caught it.
958            // So, this means no candidates were generated for other reasons.
959            warn!(target: "gemini::convert_response", "No candidates received, and prompt was not blocked.");
960            // Use NoChoices error here
961            return Err(ApiError::NoChoices {
962                provider: "google".to_string(),
963            });
964        }
965        cands // Return the non-empty vector
966    } else {
967        warn!(target: "gemini::convert_response", "No candidates field in Gemini response.");
968        // Use NoChoices error here as well
969        return Err(ApiError::NoChoices {
970            provider: "google".to_string(),
971        });
972    };
973
974    // For simplicity, still taking the first candidate. Multi-candidate handling could be added.
975    // Access candidates safely since we've checked it's not None or empty.
976    let candidate = &candidates[0];
977
978    // Log finish reason and safety ratings if present
979    if let Some(reason) = &candidate.finish_reason {
980        match reason {
981            GeminiFinishReason::Stop => { /* Normal completion */ }
982            GeminiFinishReason::MaxTokens => {
983                warn!(target: "gemini::convert_response", "Response stopped due to MaxTokens limit.");
984            }
985            GeminiFinishReason::Safety => {
986                warn!(target: "gemini::convert_response", "Response stopped due to safety settings. Ratings: {:?}", candidate.safety_ratings);
987                // Consider returning an error or modifying the response based on safety ratings
988            }
989            GeminiFinishReason::Recitation => {
990                warn!(target: "gemini::convert_response", "Response stopped due to potential recitation. Citations: {:?}", candidate.citation_metadata);
991            }
992            GeminiFinishReason::MalformedFunctionCall => {
993                warn!(target: "gemini::convert_response", "Response stopped due to malformed function call.");
994            }
995            _ => {
996                info!(target: "gemini::convert_response", "Response finished with reason: {:?}", reason);
997            }
998        }
999    }
1000
1001    // Log usage metadata if present
1002    if let Some(usage) = &response.usage_metadata {
1003        debug!(target: "gemini::convert_response", "Usage - Prompt Tokens: {:?}, Candidates Tokens: {:?}, Total Tokens: {:?}",
1004               usage.prompt, usage.candidates, usage.total);
1005    }
1006
1007    let content: Vec<AssistantContent> = candidate
1008        .content // GeminiContentResponse
1009        .parts   // Vec<GeminiResponsePart> (struct)
1010        .iter()
1011        .filter_map(|part| { // part is &GeminiResponsePart (struct)
1012            // Check if this is a thought part first
1013            if part.thought {
1014                debug!(target: "gemini::convert_response", "Received thought part: {:?}", part);
1015                // For thought parts, extract text content and create a Thought block
1016                if let GeminiResponsePartData::Text { text } = &part.data {
1017                    Some(AssistantContent::Thought {
1018                        thought: ThoughtContent::Simple {
1019                            text: text.clone(),
1020                        },
1021                    })
1022                } else {
1023                    warn!(target: "gemini::convert_response", "Thought part contains non-text data: {:?}", part.data);
1024                    None
1025                }
1026            } else {
1027                // Regular (non-thought) content processing
1028                match &part.data {
1029                    GeminiResponsePartData::Text { text } => Some(AssistantContent::Text {
1030                        text: text.clone(),
1031                    }),
1032                    GeminiResponsePartData::InlineData { inline_data } => {
1033                        warn!(target: "gemini::convert_response", "Received InlineData part (MIME type: {}). Converting to placeholder text.", inline_data.mime_type);
1034                        Some(AssistantContent::Text { text: format!("[Inline Data: {}]", inline_data.mime_type) })
1035                    }
1036                    GeminiResponsePartData::FunctionCall { function_call } => {
1037                        Some(AssistantContent::ToolCall {
1038                            tool_call: steer_tools::ToolCall {
1039                                id: uuid::Uuid::new_v4().to_string(), // Generate a synthetic ID
1040                                name: function_call.name.clone(),
1041                                parameters: function_call.args.clone(),
1042                            },
1043                            thought_signature: part
1044                                .thought_signature
1045                                .clone()
1046                                .map(ThoughtSignature::new),
1047                        })
1048                    }
1049                    GeminiResponsePartData::FileData { file_data } => {
1050                        warn!(target: "gemini::convert_response", "Received FileData part (URI: {}). Converting to placeholder text.", file_data.file_uri);
1051                        Some(AssistantContent::Text { text: format!("[File Data: {}]", file_data.file_uri) })
1052                    }
1053                     GeminiResponsePartData::ExecutableCode { executable_code } => {
1054                         info!(target: "gemini::convert_response", "Received ExecutableCode part ({}). Converting to text.",
1055                              executable_code.language);
1056                         Some(AssistantContent::Text {
1057                             text: format!(
1058                                 "```{}
1059{}
1060```",
1061                                 executable_code.language.to_lowercase(),
1062                                 executable_code.code
1063                             ),
1064                         })
1065                     }
1066                }
1067            }
1068        })
1069        .collect();
1070
1071    Ok(CompletionResponse { content })
1072}
1073
1074#[async_trait]
1075impl Provider for GeminiClient {
1076    fn name(&self) -> &'static str {
1077        "google"
1078    }
1079
1080    async fn complete(
1081        &self,
1082        model_id: &ModelId,
1083        messages: Vec<AppMessage>,
1084        system: Option<SystemContext>,
1085        tools: Option<Vec<ToolSchema>>,
1086        _call_options: Option<ModelParameters>,
1087        token: CancellationToken,
1088    ) -> Result<CompletionResponse, ApiError> {
1089        let model_name = &model_id.id; // Use the model ID string
1090        let url = format!(
1091            "{}/models/{}:generateContent?key={}",
1092            GEMINI_API_BASE, model_name, self.api_key
1093        );
1094
1095        let gemini_contents = convert_messages(messages);
1096
1097        let system_instruction = system
1098            .and_then(|context| context.render())
1099            .map(|instructions| GeminiSystemInstruction {
1100                parts: vec![GeminiRequestPart::Text { text: instructions }],
1101            });
1102
1103        let gemini_tools = tools.map(convert_tools);
1104
1105        // Derive generation config from call options, respecting model/catalog settings
1106        let (temperature, top_p, max_output_tokens) = {
1107            let opts = _call_options.as_ref();
1108            (
1109                opts.and_then(|o| o.temperature).or(Some(1.0)),
1110                opts.and_then(|o| o.top_p).or(Some(0.95)),
1111                opts.and_then(|o| o.max_tokens)
1112                    .map(|v| v as i32)
1113                    .or(Some(65536)),
1114            )
1115        };
1116        let thinking_config = _call_options
1117            .as_ref()
1118            .and_then(|o| o.thinking_config)
1119            .and_then(|tc| {
1120                if tc.enabled {
1121                    Some(GeminiThinkingConfig {
1122                        include_thoughts: tc.include_thoughts,
1123                        thinking_budget: tc.budget_tokens.map(|v| v as i32),
1124                    })
1125                } else {
1126                    None
1127                }
1128            });
1129
1130        let request = GeminiRequest {
1131            contents: gemini_contents,
1132            system_instruction,
1133            tools: gemini_tools,
1134            generation_config: Some(GeminiGenerationConfig {
1135                max_output_tokens,
1136                temperature,
1137                top_p,
1138                thinking_config,
1139                ..Default::default()
1140            }),
1141        };
1142
1143        let response = tokio::select! {
1144            biased;
1145            () = token.cancelled() => {
1146                debug!(target: "gemini::complete", "Cancellation token triggered before sending request.");
1147                return Err(ApiError::Cancelled{ provider: self.name().to_string()});
1148            }
1149            res = self.client.post(&url).json(&request).send() => {
1150                res.map_err(ApiError::Network)?
1151            }
1152        };
1153        let status = response.status();
1154
1155        if status != StatusCode::OK {
1156            let error_text = response.text().await.map_err(ApiError::Network)?;
1157            error!(target: "Gemini API Error Response", "Status: {}, Body: {}", status, error_text);
1158            return Err(match status.as_u16() {
1159                401 | 403 => ApiError::AuthenticationFailed {
1160                    provider: self.name().to_string(),
1161                    details: error_text,
1162                },
1163                429 => ApiError::RateLimited {
1164                    provider: self.name().to_string(),
1165                    details: error_text,
1166                },
1167                400 | 404 => {
1168                    error!(target: "Gemini API Error Response", "Status: {}, Body: {}, Request: {}", status, error_text, serde_json::to_string_pretty(&request).unwrap_or_else(|_| "Failed to serialize request".to_string()));
1169                    ApiError::InvalidRequest {
1170                        provider: self.name().to_string(),
1171                        details: error_text,
1172                    }
1173                } // 404 might mean invalid model
1174                500..=599 => ApiError::ServerError {
1175                    provider: self.name().to_string(),
1176                    status_code: status.as_u16(),
1177                    details: error_text,
1178                },
1179                _ => ApiError::Unknown {
1180                    provider: self.name().to_string(),
1181                    details: error_text,
1182                },
1183            });
1184        }
1185
1186        let response_text = response.text().await.map_err(ApiError::Network)?;
1187
1188        match serde_json::from_str::<GeminiResponse>(&response_text) {
1189            Ok(gemini_response) => {
1190                convert_response(gemini_response).map_err(|e| ApiError::ResponseParsingError {
1191                    provider: self.name().to_string(),
1192                    details: e.to_string(),
1193                })
1194            }
1195            Err(e) => {
1196                error!(target: "Gemini API JSON Parsing Error", "Failed to parse JSON: {}. Response body:\n{}", e, response_text);
1197                Err(ApiError::ResponseParsingError {
1198                    provider: self.name().to_string(),
1199                    details: format!("Status: {status}, Error: {e}, Body: {response_text}"),
1200                })
1201            }
1202        }
1203    }
1204
1205    async fn stream_complete(
1206        &self,
1207        model_id: &ModelId,
1208        messages: Vec<AppMessage>,
1209        system: Option<SystemContext>,
1210        tools: Option<Vec<ToolSchema>>,
1211        _call_options: Option<ModelParameters>,
1212        token: CancellationToken,
1213    ) -> Result<CompletionStream, ApiError> {
1214        let model_name = &model_id.id;
1215        let url = format!(
1216            "{}/models/{}:streamGenerateContent?alt=sse&key={}",
1217            GEMINI_API_BASE, model_name, self.api_key
1218        );
1219
1220        let gemini_contents = convert_messages(messages);
1221
1222        let system_instruction = system
1223            .and_then(|context| context.render())
1224            .map(|instructions| GeminiSystemInstruction {
1225                parts: vec![GeminiRequestPart::Text { text: instructions }],
1226            });
1227
1228        let gemini_tools = tools.map(convert_tools);
1229
1230        let (temperature, top_p, max_output_tokens) = {
1231            let opts = _call_options.as_ref();
1232            (
1233                opts.and_then(|o| o.temperature).or(Some(1.0)),
1234                opts.and_then(|o| o.top_p).or(Some(0.95)),
1235                opts.and_then(|o| o.max_tokens)
1236                    .map(|v| v as i32)
1237                    .or(Some(65536)),
1238            )
1239        };
1240        let thinking_config = _call_options
1241            .as_ref()
1242            .and_then(|o| o.thinking_config)
1243            .and_then(|tc| {
1244                if tc.enabled {
1245                    Some(GeminiThinkingConfig {
1246                        include_thoughts: tc.include_thoughts,
1247                        thinking_budget: tc.budget_tokens.map(|v| v as i32),
1248                    })
1249                } else {
1250                    None
1251                }
1252            });
1253
1254        let request = GeminiRequest {
1255            contents: gemini_contents,
1256            system_instruction,
1257            tools: gemini_tools,
1258            generation_config: Some(GeminiGenerationConfig {
1259                max_output_tokens,
1260                temperature,
1261                top_p,
1262                thinking_config,
1263                ..Default::default()
1264            }),
1265        };
1266
1267        let response = tokio::select! {
1268            biased;
1269            () = token.cancelled() => {
1270                return Err(ApiError::Cancelled{ provider: self.name().to_string()});
1271            }
1272            res = self.client.post(&url).json(&request).send() => {
1273                res.map_err(ApiError::Network)?
1274            }
1275        };
1276
1277        let status = response.status();
1278        if status != StatusCode::OK {
1279            let error_text = response.text().await.map_err(ApiError::Network)?;
1280            error!(target: "gemini::stream", "API error - Status: {}, Body: {}", status, error_text);
1281            return Err(match status.as_u16() {
1282                401 | 403 => ApiError::AuthenticationFailed {
1283                    provider: self.name().to_string(),
1284                    details: error_text,
1285                },
1286                429 => ApiError::RateLimited {
1287                    provider: self.name().to_string(),
1288                    details: error_text,
1289                },
1290                400 | 404 => ApiError::InvalidRequest {
1291                    provider: self.name().to_string(),
1292                    details: error_text,
1293                },
1294                500..=599 => ApiError::ServerError {
1295                    provider: self.name().to_string(),
1296                    status_code: status.as_u16(),
1297                    details: error_text,
1298                },
1299                _ => ApiError::Unknown {
1300                    provider: self.name().to_string(),
1301                    details: error_text,
1302                },
1303            });
1304        }
1305
1306        let byte_stream = response.bytes_stream();
1307        let sse_stream = parse_sse_stream(byte_stream);
1308
1309        Ok(Box::pin(Self::convert_gemini_stream(sse_stream, token)))
1310    }
1311}
1312
1313impl GeminiClient {
1314    fn convert_gemini_stream(
1315        mut sse_stream: impl futures::Stream<Item = Result<crate::api::sse::SseEvent, SseParseError>>
1316        + Unpin
1317        + Send
1318        + 'static,
1319        token: CancellationToken,
1320    ) -> impl futures::Stream<Item = StreamChunk> + Send + 'static {
1321        async_stream::stream! {
1322            let mut content: Vec<AssistantContent> = Vec::new();
1323            loop {
1324                if token.is_cancelled() {
1325                    yield StreamChunk::Error(StreamError::Cancelled);
1326                    break;
1327                }
1328
1329                let event_result = tokio::select! {
1330                    biased;
1331                    () = token.cancelled() => {
1332                        yield StreamChunk::Error(StreamError::Cancelled);
1333                        break;
1334                    }
1335                    event = sse_stream.next() => event
1336                };
1337
1338                let Some(event_result) = event_result else {
1339                    let content = std::mem::take(&mut content);
1340                    yield StreamChunk::MessageComplete(CompletionResponse { content });
1341                    break;
1342                };
1343
1344                let event = match event_result {
1345                    Ok(e) => e,
1346                    Err(e) => {
1347                        yield StreamChunk::Error(StreamError::SseParse(e));
1348                        break;
1349                    }
1350                };
1351
1352                let chunk: GeminiResponse = match serde_json::from_str(&event.data) {
1353                    Ok(c) => c,
1354                    Err(e) => {
1355                        debug!(target: "gemini::stream", "Failed to parse chunk: {} data: {}", e, event.data);
1356                        continue;
1357                    }
1358                };
1359
1360                if let Some(candidates) = chunk.candidates {
1361                    for candidate in candidates {
1362                        for part in candidate.content.parts {
1363                            let GeminiResponsePart {
1364                                thought,
1365                                thought_signature,
1366                                data,
1367                            } = part;
1368                            let thought_signature =
1369                                thought_signature.map(ThoughtSignature::new);
1370
1371                            if thought {
1372                                if let GeminiResponsePartData::Text { text } = data {
1373                                    match content.last_mut() {
1374                                        Some(AssistantContent::Thought {
1375                                            thought: ThoughtContent::Simple { text: buf },
1376                                        }) => buf.push_str(&text),
1377                                        _ => {
1378                                            content.push(AssistantContent::Thought {
1379                                                thought: ThoughtContent::Simple { text: text.clone() },
1380                                            });
1381                                        }
1382                                    }
1383                                    yield StreamChunk::ThinkingDelta(text);
1384                                }
1385                            } else {
1386                                match data {
1387                                    GeminiResponsePartData::Text { text } => {
1388                                        match content.last_mut() {
1389                                            Some(AssistantContent::Text { text: buf }) => buf.push_str(&text),
1390                                            _ => {
1391                                                content.push(AssistantContent::Text { text: text.clone() });
1392                                            }
1393                                        }
1394                                        yield StreamChunk::TextDelta(text);
1395                                    }
1396                                    GeminiResponsePartData::FunctionCall { function_call } => {
1397                                        let id = uuid::Uuid::new_v4().to_string();
1398                                        content.push(AssistantContent::ToolCall {
1399                                            tool_call: steer_tools::ToolCall {
1400                                                id: id.clone(),
1401                                                name: function_call.name.clone(),
1402                                                parameters: function_call.args.clone(),
1403                                            },
1404                                            thought_signature,
1405                                        });
1406                                        yield StreamChunk::ToolUseStart {
1407                                            id: id.clone(),
1408                                            name: function_call.name,
1409                                        };
1410                                        yield StreamChunk::ToolUseInputDelta {
1411                                            id,
1412                                            delta: function_call.args.to_string(),
1413                                        };
1414                                    }
1415                                    _ => {}
1416                                }
1417                            }
1418                        }
1419                    }
1420                }
1421            }
1422        }
1423    }
1424}
1425
1426#[cfg(test)]
1427mod tests {
1428    use super::*;
1429    use serde_json::json;
1430
1431    #[test]
1432    fn test_simplify_property_schema_removes_additional_properties() {
1433        let property_value = json!({
1434            "type": "object",
1435            "properties": {
1436                "name": {"type": "string"}
1437            },
1438            "additionalProperties": false
1439        });
1440
1441        let expected = json!({
1442            "type": "object",
1443            "properties": {
1444                "name": {"type": "string"}
1445            }
1446        });
1447
1448        let result = simplify_property_schema("testProp", "testTool", &property_value);
1449        assert_eq!(result, expected);
1450    }
1451
1452    #[test]
1453    fn test_simplify_property_schema_removes_unsupported_string_formats() {
1454        let property_value = json!({
1455            "type": "string",
1456            "format": "uri",
1457            "minLength": 1,
1458            "maxLength": 100,
1459            "pattern": "^https://"
1460        });
1461
1462        let expected = json!({
1463            "type": "string"
1464        });
1465
1466        let result = simplify_property_schema("urlProp", "testTool", &property_value);
1467        assert_eq!(result, expected);
1468    }
1469
1470    #[test]
1471    fn test_simplify_property_schema_keeps_supported_string_formats() {
1472        let property_value = json!({
1473            "type": "string",
1474            "format": "date-time"
1475        });
1476
1477        let expected = json!({
1478            "type": "string",
1479            "format": "date-time"
1480        });
1481
1482        let result = simplify_property_schema("dateProp", "testTool", &property_value);
1483        assert_eq!(result, expected);
1484    }
1485
1486    #[test]
1487    fn test_simplify_property_schema_handles_array_types() {
1488        let property_value = json!({
1489            "type": ["string", "null"],
1490            "format": "email"
1491        });
1492
1493        let expected = json!({
1494            "type": "string"
1495        });
1496
1497        let result = simplify_property_schema("emailProp", "testTool", &property_value);
1498        assert_eq!(result, expected);
1499    }
1500
1501    #[test]
1502    fn test_simplify_property_schema_recursively_handles_array_items() {
1503        let property_value = json!({
1504            "type": "array",
1505            "items": {
1506                "type": "object",
1507                "properties": {
1508                    "url": {
1509                        "type": "string",
1510                        "format": "uri"
1511                    }
1512                },
1513                "additionalProperties": false
1514            }
1515        });
1516
1517        let expected = json!({
1518            "type": "array",
1519            "items": {
1520                "type": "object",
1521                "properties": {
1522                    "url": {
1523                        "type": "string"
1524                    }
1525                }
1526            }
1527        });
1528
1529        let result = simplify_property_schema("linksProp", "testTool", &property_value);
1530        assert_eq!(result, expected);
1531    }
1532
1533    #[test]
1534    fn test_simplify_property_schema_recursively_handles_nested_objects() {
1535        let property_value = json!({
1536            "type": "object",
1537            "properties": {
1538                "nested": {
1539                    "type": "object",
1540                    "properties": {
1541                        "field": {
1542                            "type": "string",
1543                            "format": "hostname"
1544                        }
1545                    },
1546                    "additionalProperties": true
1547                }
1548            },
1549            "additionalProperties": false
1550        });
1551
1552        let expected = json!({
1553            "type": "object",
1554            "properties": {
1555                "nested": {
1556                    "type": "object",
1557                    "properties": {
1558                        "field": {
1559                            "type": "string"
1560                        }
1561                    }
1562                }
1563            }
1564        });
1565
1566        let result = simplify_property_schema("complexProp", "testTool", &property_value);
1567        assert_eq!(result, expected);
1568    }
1569
1570    #[test]
1571    fn test_simplify_property_schema_fixes_uint64_format() {
1572        let property_value = json!({
1573            "type": "integer",
1574            "format": "uint64"
1575        });
1576
1577        let expected = json!({
1578            "type": "integer",
1579            "format": "int64"
1580        });
1581
1582        let result = simplify_property_schema("idProp", "testTool", &property_value);
1583        assert_eq!(result, expected);
1584    }
1585
1586    #[test]
1587    fn test_convert_tools_integration() {
1588        use steer_tools::{InputSchema, ToolSchema};
1589
1590        let tool = ToolSchema {
1591            name: "create_issue".to_string(),
1592            display_name: "Create Issue".to_string(),
1593            description: "Create an issue".to_string(),
1594            input_schema: InputSchema::object(
1595                {
1596                    let mut props = serde_json::Map::new();
1597                    props.insert(
1598                        "title".to_string(),
1599                        json!({
1600                            "type": "string",
1601                            "minLength": 1
1602                        }),
1603                    );
1604                    props.insert(
1605                        "links".to_string(),
1606                        json!({
1607                            "type": "array",
1608                            "items": {
1609                                "type": "object",
1610                                "properties": {
1611                                    "url": {
1612                                        "type": "string",
1613                                        "format": "uri"
1614                                    }
1615                                },
1616                                "additionalProperties": false
1617                            }
1618                        }),
1619                    );
1620                    props
1621                },
1622                vec!["title".to_string()],
1623            ),
1624        };
1625
1626        let expected_tools = vec![GeminiTool {
1627            function_declarations: vec![GeminiFunctionDeclaration {
1628                name: "create_issue".to_string(),
1629                description: "Create an issue".to_string(),
1630                parameters: GeminiParameterSchema {
1631                    schema_type: "object".to_string(),
1632                    properties: {
1633                        let mut props = serde_json::Map::new();
1634                        props.insert(
1635                            "title".to_string(),
1636                            json!({
1637                                "type": "string"
1638                            }),
1639                        );
1640                        props.insert(
1641                            "links".to_string(),
1642                            json!({
1643                                "type": "array",
1644                                "items": {
1645                                    "type": "object",
1646                                    "properties": {
1647                                        "url": {
1648                                            "type": "string"
1649                                        }
1650                                    }
1651                                }
1652                            }),
1653                        );
1654                        props
1655                    },
1656                    required: vec!["title".to_string()],
1657                },
1658            }],
1659        }];
1660
1661        let result = convert_tools(vec![tool]);
1662        assert_eq!(result, expected_tools);
1663    }
1664
1665    #[tokio::test]
1666    async fn test_convert_gemini_stream_text_deltas() {
1667        use crate::api::provider::StreamChunk;
1668        use crate::api::sse::SseEvent;
1669        use futures::StreamExt;
1670        use futures::stream;
1671        use std::pin::pin;
1672        use tokio_util::sync::CancellationToken;
1673
1674        let events = vec![
1675            Ok(SseEvent {
1676                event_type: None,
1677                data: r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#
1678                    .to_string(),
1679                id: None,
1680            }),
1681            Ok(SseEvent {
1682                event_type: None,
1683                data:
1684                    r#"{"candidates":[{"content":{"role":"model","parts":[{"text":" world"}]}}]}"#
1685                        .to_string(),
1686                id: None,
1687            }),
1688        ];
1689
1690        let sse_stream = stream::iter(events);
1691        let token = CancellationToken::new();
1692        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));
1693
1694        let first_delta = stream.next().await.unwrap();
1695        assert!(matches!(first_delta, StreamChunk::TextDelta(ref t) if t == "Hello"));
1696
1697        let second_delta = stream.next().await.unwrap();
1698        assert!(matches!(second_delta, StreamChunk::TextDelta(ref t) if t == " world"));
1699
1700        let complete = stream.next().await.unwrap();
1701        assert!(matches!(complete, StreamChunk::MessageComplete(_)));
1702    }
1703
1704    #[tokio::test]
1705    async fn test_convert_gemini_stream_with_thinking() {
1706        use crate::api::provider::StreamChunk;
1707        use crate::api::sse::SseEvent;
1708        use futures::StreamExt;
1709        use futures::stream;
1710        use std::pin::pin;
1711        use tokio_util::sync::CancellationToken;
1712
1713        let events = vec![
1714            Ok(SseEvent {
1715                event_type: None,
1716                data: r#"{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"Let me think..."}]}}]}"#.to_string(),
1717                id: None,
1718            }),
1719            Ok(SseEvent {
1720                event_type: None,
1721                data: r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"The answer"}]}}]}"#.to_string(),
1722                id: None,
1723            }),
1724        ];
1725
1726        let sse_stream = stream::iter(events);
1727        let token = CancellationToken::new();
1728        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));
1729
1730        let thinking_delta = stream.next().await.unwrap();
1731        assert!(
1732            matches!(thinking_delta, StreamChunk::ThinkingDelta(ref t) if t == "Let me think...")
1733        );
1734
1735        let text_delta = stream.next().await.unwrap();
1736        assert!(matches!(text_delta, StreamChunk::TextDelta(ref t) if t == "The answer"));
1737
1738        let complete = stream.next().await.unwrap();
1739        if let StreamChunk::MessageComplete(response) = complete {
1740            assert_eq!(response.content.len(), 2);
1741            assert!(matches!(
1742                &response.content[0],
1743                AssistantContent::Thought { .. }
1744            ));
1745            assert!(matches!(
1746                &response.content[1],
1747                AssistantContent::Text { .. }
1748            ));
1749        } else {
1750            panic!("Expected MessageComplete");
1751        }
1752    }
1753
1754    #[tokio::test]
1755    async fn test_convert_gemini_stream_with_function_call() {
1756        use crate::api::provider::StreamChunk;
1757        use crate::api::sse::SseEvent;
1758        use futures::StreamExt;
1759        use futures::stream;
1760        use std::pin::pin;
1761        use tokio_util::sync::CancellationToken;
1762
1763        let events = vec![
1764            Ok(SseEvent {
1765                event_type: None,
1766                data: r#"{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"get_weather","args":{"city":"NYC"}},"thoughtSignature":"sig_123"}]}}]}"#.to_string(),
1767                id: None,
1768            }),
1769        ];
1770
1771        let sse_stream = stream::iter(events);
1772        let token = CancellationToken::new();
1773        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));
1774
1775        let tool_start = stream.next().await.unwrap();
1776        assert!(
1777            matches!(tool_start, StreamChunk::ToolUseStart { ref name, .. } if name == "get_weather")
1778        );
1779
1780        let tool_input = stream.next().await.unwrap();
1781        assert!(matches!(tool_input, StreamChunk::ToolUseInputDelta { .. }));
1782
1783        let complete = stream.next().await.unwrap();
1784        if let StreamChunk::MessageComplete(response) = complete {
1785            assert_eq!(response.content.len(), 1);
1786            if let AssistantContent::ToolCall {
1787                tool_call,
1788                thought_signature,
1789            } = &response.content[0]
1790            {
1791                assert_eq!(tool_call.name, "get_weather");
1792                assert_eq!(
1793                    thought_signature.as_ref().map(|sig| sig.as_str()),
1794                    Some("sig_123")
1795                );
1796            } else {
1797                panic!("Expected ToolCall");
1798            }
1799        } else {
1800            panic!("Expected MessageComplete");
1801        }
1802    }
1803
1804    #[tokio::test]
1805    async fn test_convert_gemini_stream_cancellation() {
1806        use crate::api::error::StreamError;
1807        use crate::api::provider::StreamChunk;
1808        use crate::api::sse::SseEvent;
1809        use futures::StreamExt;
1810        use futures::stream;
1811        use std::pin::pin;
1812        use tokio_util::sync::CancellationToken;
1813
1814        let events = vec![Ok(SseEvent {
1815            event_type: None,
1816            data: r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#
1817                .to_string(),
1818            id: None,
1819        })];
1820
1821        let sse_stream = stream::iter(events);
1822        let token = CancellationToken::new();
1823        token.cancel();
1824
1825        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));
1826
1827        let cancelled = stream.next().await.unwrap();
1828        assert!(matches!(
1829            cancelled,
1830            StreamChunk::Error(StreamError::Cancelled)
1831        ));
1832    }
1833
1834    #[tokio::test]
1835    #[ignore = "Requires GOOGLE_API_KEY environment variable"]
1836    async fn test_stream_complete_real_api() {
1837        use crate::api::Provider;
1838        use crate::api::provider::StreamChunk;
1839        use crate::app::conversation::{Message, MessageData, UserContent};
1840        use futures::StreamExt;
1841        use tokio_util::sync::CancellationToken;
1842
1843        dotenvy::dotenv().ok();
1844        let api_key = std::env::var("GOOGLE_API_KEY").expect("GOOGLE_API_KEY must be set");
1845        let client = GeminiClient::new(api_key);
1846
1847        let message = Message {
1848            data: MessageData::User {
1849                content: vec![UserContent::Text {
1850                    text: "Say exactly: Hello".to_string(),
1851                }],
1852            },
1853            timestamp: chrono::Utc::now().timestamp_millis() as u64,
1854            id: "test-msg".to_string(),
1855            parent_message_id: None,
1856        };
1857
1858        let model_id = ModelId::new(
1859            crate::config::provider::google(),
1860            "gemini-2.5-flash-preview-04-17",
1861        );
1862        let token = CancellationToken::new();
1863
1864        let mut stream = client
1865            .stream_complete(&model_id, vec![message], None, None, None, token)
1866            .await
1867            .expect("stream_complete should succeed");
1868
1869        let mut got_text_delta = false;
1870        let mut got_message_complete = false;
1871        let mut accumulated_text = String::new();
1872
1873        while let Some(chunk) = stream.next().await {
1874            match chunk {
1875                StreamChunk::TextDelta(text) => {
1876                    got_text_delta = true;
1877                    accumulated_text.push_str(&text);
1878                }
1879                StreamChunk::MessageComplete(response) => {
1880                    got_message_complete = true;
1881                    assert!(!response.content.is_empty());
1882                }
1883                StreamChunk::Error(e) => panic!("Unexpected error: {e:?}"),
1884                _ => {}
1885            }
1886        }
1887
1888        assert!(got_text_delta, "Should receive at least one TextDelta");
1889        assert!(
1890            got_message_complete,
1891            "Should receive MessageComplete at the end"
1892        );
1893        assert!(
1894            accumulated_text.to_lowercase().contains("hello"),
1895            "Response should contain 'hello', got: {accumulated_text}"
1896        );
1897    }
1898}