1use async_trait::async_trait;
2use futures::StreamExt;
3use reqwest::{Client as HttpClient, StatusCode};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use tokio_util::sync::CancellationToken;
7use tracing::{debug, error, info, warn};
8
9use crate::api::error::{ApiError, SseParseError, StreamError};
10use crate::api::provider::{CompletionResponse, CompletionStream, Provider, StreamChunk};
11use crate::api::sse::parse_sse_stream;
12use crate::app::SystemContext;
13use crate::app::conversation::{
14 AssistantContent, ImageSource, Message as AppMessage, ThoughtContent, ThoughtSignature,
15 ToolResult, UserContent,
16};
17use crate::config::model::{ModelId, ModelParameters};
18use steer_tools::ToolSchema;
19
/// Base URL of the Gemini REST API (`v1beta`). Request URLs are built as
/// `{GEMINI_API_BASE}/models/{model}:{method}`.
const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
21
22#[derive(Debug, Deserialize, Serialize, Clone)] struct GeminiBlob {
24 #[serde(rename = "mimeType")]
25 mime_type: String,
26 data: String, }
28
29#[derive(Debug, Deserialize, Serialize, Clone)] struct GeminiFileData {
31 #[serde(rename = "mimeType")]
32 mime_type: String,
33 #[serde(rename = "fileUri")]
34 file_uri: String,
35}
36
37#[expect(dead_code)]
38#[derive(Debug, Deserialize, Serialize, Clone)] struct GeminiCodeExecutionResult {
40 outcome: String, }
43
44pub struct GeminiClient {
45 api_key: String,
46 client: HttpClient,
47}
48
49impl GeminiClient {
50 pub fn new(api_key: impl Into<String>) -> Self {
51 Self {
52 api_key: api_key.into(),
53 client: HttpClient::new(),
54 }
55 }
56}
57
58#[derive(Debug, Serialize)]
59struct GeminiRequest {
60 contents: Vec<GeminiContent>,
61 #[serde(skip_serializing_if = "Option::is_none")]
62 #[serde(rename = "systemInstruction")]
63 system_instruction: Option<GeminiSystemInstruction>,
64 #[serde(skip_serializing_if = "Option::is_none")]
65 tools: Option<Vec<GeminiTool>>,
66 #[serde(skip_serializing_if = "Option::is_none")]
67 #[serde(rename = "generationConfig")]
68 generation_config: Option<GeminiGenerationConfig>,
69}
70
71#[derive(Debug, Serialize, Default, Clone)]
72struct GeminiGenerationConfig {
73 #[serde(skip_serializing_if = "Option::is_none")]
74 #[serde(rename = "stopSequences")]
75 stop_sequences: Option<Vec<String>>,
76 #[serde(skip_serializing_if = "Option::is_none")]
77 #[serde(rename = "responseMimeType")]
78 response_mime_type: Option<GeminiMimeType>,
79 #[serde(skip_serializing_if = "Option::is_none")]
80 #[serde(rename = "candidateCount")]
81 candidate_count: Option<i32>,
82 #[serde(skip_serializing_if = "Option::is_none")]
83 #[serde(rename = "maxOutputTokens")]
84 max_output_tokens: Option<i32>,
85 #[serde(skip_serializing_if = "Option::is_none")]
86 temperature: Option<f32>,
87 #[serde(skip_serializing_if = "Option::is_none")]
88 #[serde(rename = "topP")]
89 top_p: Option<f32>,
90 #[serde(skip_serializing_if = "Option::is_none")]
91 #[serde(rename = "topK")]
92 top_k: Option<i32>,
93 #[serde(skip_serializing_if = "Option::is_none")]
94 #[serde(rename = "thinkingConfig")]
95 thinking_config: Option<GeminiThinkingConfig>,
96}
97
98#[derive(Debug, Serialize, Default, Clone)]
99struct GeminiThinkingConfig {
100 #[serde(skip_serializing_if = "Option::is_none")]
101 #[serde(rename = "includeThoughts")]
102 include_thoughts: Option<bool>,
103 #[serde(skip_serializing_if = "Option::is_none")]
104 #[serde(rename = "thinkingBudget")]
105 thinking_budget: Option<i32>,
106}
107
108#[expect(dead_code)]
109#[derive(Debug, Serialize, Clone)]
110#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
111enum GeminiMimeType {
112 MimeTypeUnspecified,
113 TextPlain,
114 ApplicationJson,
115}
116
117#[derive(Debug, Serialize)]
118struct GeminiSystemInstruction {
119 parts: Vec<GeminiRequestPart>,
120}
121
122#[derive(Debug, Serialize)]
123struct GeminiContent {
124 role: String,
125 parts: Vec<GeminiRequestPart>,
126}
127
128#[derive(Debug, Serialize)]
130#[serde(untagged)]
131enum GeminiRequestPart {
132 Text {
133 text: String,
134 },
135 #[serde(rename = "inlineData")]
136 InlineData {
137 #[serde(rename = "inlineData")]
138 inline_data: GeminiBlob,
139 },
140 #[serde(rename = "fileData")]
141 FileData {
142 #[serde(rename = "fileData")]
143 file_data: GeminiFileData,
144 },
145 #[serde(rename = "functionCall")]
146 FunctionCall {
147 #[serde(rename = "functionCall")]
148 function_call: GeminiFunctionCall, #[serde(rename = "thoughtSignature", skip_serializing_if = "Option::is_none")]
150 thought_signature: Option<String>,
151 },
152 #[serde(rename = "functionResponse")]
153 FunctionResponse {
154 #[serde(rename = "functionResponse")]
155 function_response: GeminiFunctionResponse, },
157}
158
159#[derive(Debug, Deserialize)]
161#[serde(untagged)]
162enum GeminiResponsePartData {
163 Text {
164 text: String,
165 },
166 #[serde(rename = "inlineData")]
167 InlineData {
168 #[serde(rename = "inlineData")]
169 inline_data: GeminiBlob,
170 },
171 #[serde(rename = "functionCall")]
172 FunctionCall {
173 #[serde(rename = "functionCall")]
174 function_call: GeminiFunctionCall,
175 },
176 #[serde(rename = "fileData")]
177 FileData {
178 #[serde(rename = "fileData")]
179 file_data: GeminiFileData,
180 },
181 #[serde(rename = "executableCode")]
182 ExecutableCode {
183 #[serde(rename = "executableCode")]
184 executable_code: GeminiExecutableCode,
185 },
186 }
188
189#[derive(Debug, Deserialize)]
191struct GeminiResponsePart {
192 #[serde(default)] thought: bool,
194 #[serde(default, rename = "thoughtSignature", alias = "thought_signature")]
195 thought_signature: Option<String>,
196
197 #[serde(flatten)] data: GeminiResponsePartData,
199}
200
201#[derive(Debug, Serialize, Deserialize)]
202struct GeminiFunctionCall {
203 name: String,
204 args: Value,
205}
206
207#[derive(Debug, Serialize, PartialEq)]
208struct GeminiTool {
209 #[serde(rename = "functionDeclarations")]
210 function_declarations: Vec<GeminiFunctionDeclaration>,
211}
212
213#[derive(Debug, Serialize, PartialEq)]
214struct GeminiFunctionDeclaration {
215 name: String,
216 description: String,
217 parameters: GeminiParameterSchema,
218}
219
220#[derive(Debug, Serialize, PartialEq)]
221struct GeminiParameterSchema {
222 #[serde(rename = "type")]
223 schema_type: String, properties: serde_json::Map<String, Value>,
225 required: Vec<String>,
226}
227
228#[derive(Debug, Deserialize)]
229struct GeminiResponse {
230 #[serde(rename = "candidates")]
231 #[serde(skip_serializing_if = "Option::is_none")]
232 candidates: Option<Vec<GeminiCandidate>>,
233 #[serde(rename = "promptFeedback")]
234 #[serde(skip_serializing_if = "Option::is_none")]
235 prompt_feedback: Option<GeminiPromptFeedback>,
236 #[serde(rename = "usageMetadata")]
237 #[serde(skip_serializing_if = "Option::is_none")]
238 usage_metadata: Option<GeminiUsageMetadata>,
239}
240
241#[derive(Debug, Deserialize)]
242struct GeminiCandidate {
243 content: GeminiContentResponse,
244 #[serde(rename = "finishReason")]
245 #[serde(skip_serializing_if = "Option::is_none")]
246 finish_reason: Option<GeminiFinishReason>,
247 #[serde(rename = "safetyRatings")]
248 #[serde(skip_serializing_if = "Option::is_none")]
249 safety_ratings: Option<Vec<GeminiSafetyRating>>,
250 #[serde(rename = "citationMetadata")]
251 #[serde(skip_serializing_if = "Option::is_none")]
252 citation_metadata: Option<GeminiCitationMetadata>,
253}
254
255#[derive(Debug, Deserialize)]
256#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
257enum GeminiFinishReason {
258 FinishReasonUnspecified,
259 Stop,
260 MaxTokens,
261 Safety,
262 Recitation,
263 Other,
264 #[serde(rename = "TOOL_CODE_ERROR")]
265 ToolCodeError,
266 #[serde(rename = "TOOL_EXECUTION_HALT")]
267 ToolExecutionHalt,
268 MalformedFunctionCall,
269}
270
271#[derive(Debug, Deserialize)]
272struct GeminiPromptFeedback {
273 #[serde(rename = "blockReason")]
274 #[serde(skip_serializing_if = "Option::is_none")]
275 block_reason: Option<GeminiBlockReason>,
276 #[serde(rename = "safetyRatings")]
277 #[serde(skip_serializing_if = "Option::is_none")]
278 safety_ratings: Option<Vec<GeminiSafetyRating>>,
279}
280
281#[derive(Debug, Deserialize)]
282#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
283enum GeminiBlockReason {
284 BlockReasonUnspecified,
285 Safety,
286 Other,
287}
288
289#[derive(Debug, Deserialize)]
290#[expect(dead_code)]
291struct GeminiSafetyRating {
292 category: GeminiHarmCategory,
293 probability: GeminiHarmProbability,
294 #[serde(default)] blocked: bool,
296}
297
298#[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
300#[expect(clippy::enum_variant_names)]
301enum GeminiHarmCategory {
302 HarmCategoryUnspecified,
303 HarmCategoryDerogatory,
304 HarmCategoryToxicity,
305 HarmCategoryViolence,
306 HarmCategorySexual,
307 HarmCategoryMedical,
308 HarmCategoryDangerous,
309 HarmCategoryHarassment,
310 HarmCategoryHateSpeech,
311 HarmCategorySexuallyExplicit,
312 HarmCategoryDangerousContent,
313 HarmCategoryCivicIntegrity,
314}
315
316#[derive(Debug, Deserialize)]
317#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
318enum GeminiHarmProbability {
319 HarmProbabilityUnspecified,
320 Negligible,
321 Low,
322 Medium,
323 High,
324}
325
326#[expect(dead_code)]
327#[derive(Debug, Deserialize)]
328struct GeminiCitationMetadata {
329 #[serde(rename = "citationSources")]
330 #[serde(skip_serializing_if = "Option::is_none")]
331 citation_sources: Option<Vec<GeminiCitationSource>>,
332}
333
334#[expect(dead_code)]
335#[derive(Debug, Deserialize)]
336struct GeminiCitationSource {
337 #[serde(rename = "startIndex")]
338 #[serde(skip_serializing_if = "Option::is_none")]
339 start_index: Option<i32>,
340 #[serde(rename = "endIndex")]
341 #[serde(skip_serializing_if = "Option::is_none")]
342 end_index: Option<i32>,
343 #[serde(skip_serializing_if = "Option::is_none")]
344 uri: Option<String>,
345 #[serde(skip_serializing_if = "Option::is_none")]
346 license: Option<String>,
347}
348
349#[derive(Debug, Deserialize)]
350struct GeminiUsageMetadata {
351 #[serde(rename = "promptTokenCount")]
352 #[serde(skip_serializing_if = "Option::is_none")]
353 prompt: Option<i32>,
354 #[serde(rename = "candidatesTokenCount")]
355 #[serde(skip_serializing_if = "Option::is_none")]
356 candidates: Option<i32>,
357 #[serde(rename = "totalTokenCount")]
358 #[serde(skip_serializing_if = "Option::is_none")]
359 total: Option<i32>,
360}
361
362#[derive(Debug, Serialize, Deserialize)]
363struct GeminiFunctionResponse {
364 name: String,
365 response: GeminiResponseContent,
366}
367
368#[derive(Debug, Serialize, Deserialize)]
369struct GeminiResponseContent {
370 content: Value,
371}
372
373#[derive(Debug, Serialize, Deserialize)]
374struct GeminiExecutableCode {
375 language: String, code: String,
377}
378
379#[derive(Debug, Deserialize)]
380#[expect(dead_code)]
381struct GeminiContentResponse {
382 role: String,
383 parts: Vec<GeminiResponsePart>,
384}
385
386fn user_image_to_request_part(
387 image: &crate::app::conversation::ImageContent,
388) -> Result<GeminiRequestPart, ApiError> {
389 match &image.source {
390 ImageSource::DataUrl { data_url } => {
391 let Some((meta, payload)) = data_url.split_once(',') else {
392 return Err(ApiError::InvalidRequest {
393 provider: "google".to_string(),
394 details: "Image data URL is missing ',' separator".to_string(),
395 });
396 };
397
398 let Some(meta_body) = meta.strip_prefix("data:") else {
399 return Err(ApiError::InvalidRequest {
400 provider: "google".to_string(),
401 details: "Image data URL must start with 'data:'".to_string(),
402 });
403 };
404
405 if !meta_body
406 .split(';')
407 .any(|segment| segment.eq_ignore_ascii_case("base64"))
408 {
409 return Err(ApiError::InvalidRequest {
410 provider: "google".to_string(),
411 details: "Gemini image data URLs must be base64 encoded".to_string(),
412 });
413 }
414
415 let mime_type = meta_body
416 .split(';')
417 .next()
418 .filter(|value| !value.trim().is_empty())
419 .unwrap_or("application/octet-stream")
420 .to_string();
421
422 if !mime_type.starts_with("image/") {
423 return Err(ApiError::UnsupportedFeature {
424 provider: "google".to_string(),
425 feature: "image input mime type".to_string(),
426 details: format!("Unsupported image media type '{}'", mime_type),
427 });
428 }
429
430 Ok(GeminiRequestPart::InlineData {
431 inline_data: GeminiBlob {
432 mime_type,
433 data: payload.to_string(),
434 },
435 })
436 }
437 ImageSource::Url { url } => Ok(GeminiRequestPart::FileData {
438 file_data: GeminiFileData {
439 mime_type: image.mime_type.clone(),
440 file_uri: url.clone(),
441 },
442 }),
443 ImageSource::SessionFile { relative_path } => Err(ApiError::UnsupportedFeature {
444 provider: "google".to_string(),
445 feature: "image input source".to_string(),
446 details: format!(
447 "Gemini adapter cannot access session file '{}' directly; use data URLs or URL sources",
448 relative_path
449 ),
450 }),
451 }
452}
453
454fn convert_messages(messages: Vec<AppMessage>) -> Result<Vec<GeminiContent>, ApiError> {
455 let mut converted = Vec::new();
456
457 for msg in messages {
458 match &msg.data {
459 crate::app::conversation::MessageData::User { content, .. } => {
460 let mut parts = Vec::new();
461 for user_content in content {
462 match user_content {
463 UserContent::Text { text } => {
464 parts.push(GeminiRequestPart::Text { text: text.clone() });
465 }
466 UserContent::Image { image } => {
467 parts.push(user_image_to_request_part(image)?);
468 }
469 UserContent::CommandExecution {
470 command,
471 stdout,
472 stderr,
473 exit_code,
474 } => {
475 parts.push(GeminiRequestPart::Text {
476 text: UserContent::format_command_execution_as_xml(
477 command, stdout, stderr, *exit_code,
478 ),
479 });
480 }
481 }
482 }
483
484 if !parts.is_empty() {
485 converted.push(GeminiContent {
486 role: "user".to_string(),
487 parts,
488 });
489 }
490 }
491 crate::app::conversation::MessageData::Assistant { content, .. } => {
492 let mut parts = Vec::new();
493 for assistant_content in content {
494 match assistant_content {
495 AssistantContent::Text { text } => {
496 parts.push(GeminiRequestPart::Text { text: text.clone() });
497 }
498 AssistantContent::Image { image } => {
499 match user_image_to_request_part(image) {
500 Ok(part) => parts.push(part),
501 Err(err) => {
502 debug!(
503 target: "gemini::convert_messages",
504 "Skipping unsupported assistant image block: {}",
505 err
506 );
507 }
508 }
509 }
510 AssistantContent::ToolCall {
511 tool_call,
512 thought_signature,
513 } => parts.push(GeminiRequestPart::FunctionCall {
514 function_call: GeminiFunctionCall {
515 name: tool_call.name.clone(),
516 args: tool_call.parameters.clone(),
517 },
518 thought_signature: thought_signature
519 .as_ref()
520 .map(|signature| signature.as_str().to_string()),
521 }),
522 AssistantContent::Thought { .. } => {
523 }
525 }
526 }
527
528 converted.push(GeminiContent {
529 role: "model".to_string(),
530 parts,
531 });
532 }
533 crate::app::conversation::MessageData::Tool {
534 tool_use_id,
535 result,
536 ..
537 } => {
538 let result_value = match result {
539 ToolResult::Error(e) => Value::String(format!("Error: {e}")),
540 _ => serde_json::to_value(result)
541 .unwrap_or_else(|_| Value::String(result.llm_format())),
542 };
543
544 let parts = vec![GeminiRequestPart::FunctionResponse {
545 function_response: GeminiFunctionResponse {
546 name: tool_use_id.clone(),
547 response: GeminiResponseContent {
548 content: result_value,
549 },
550 },
551 }];
552
553 converted.push(GeminiContent {
554 role: "function".to_string(),
555 parts,
556 });
557 }
558 }
559 }
560
561 Ok(converted)
562}
563
564fn resolve_ref<'a>(root: &'a Value, schema: &'a Value) -> Option<&'a Value> {
565 let reference = schema.get("$ref").and_then(|v| v.as_str())?;
566 let path = reference.strip_prefix("#/")?;
567 let mut current = root;
568 for segment in path.split('/') {
569 current = current.get(segment)?;
570 }
571 Some(current)
572}
573
574fn infer_type_from_enum(values: &[Value]) -> Option<String> {
575 let mut has_string = false;
576 let mut has_number = false;
577 let mut has_bool = false;
578 let mut has_object = false;
579 let mut has_array = false;
580
581 for value in values {
582 match value {
583 Value::String(_) => has_string = true,
584 Value::Number(_) => has_number = true,
585 Value::Bool(_) => has_bool = true,
586 Value::Object(_) => has_object = true,
587 Value::Array(_) => has_array = true,
588 Value::Null => {}
589 }
590 }
591
592 let kind_count = u8::from(has_string)
593 + u8::from(has_number)
594 + u8::from(has_bool)
595 + u8::from(has_object)
596 + u8::from(has_array);
597
598 if kind_count != 1 {
599 return None;
600 }
601
602 if has_string {
603 Some("string".to_string())
604 } else if has_number {
605 Some("number".to_string())
606 } else if has_bool {
607 Some("boolean".to_string())
608 } else if has_object {
609 Some("object".to_string())
610 } else if has_array {
611 Some("array".to_string())
612 } else {
613 None
614 }
615}
616
617fn normalize_type(value: &Value) -> Value {
618 if let Some(type_str) = value.as_str() {
619 return Value::String(type_str.to_string());
620 }
621
622 if let Some(type_array) = value.as_array()
623 && let Some(primary_type) = type_array
624 .iter()
625 .find_map(|v| if v.is_null() { None } else { v.as_str() })
626 {
627 return Value::String(primary_type.to_string());
628 }
629
630 Value::String("string".to_string())
631}
632
633fn extract_enum_values(value: &Value) -> Vec<Value> {
634 let Some(obj) = value.as_object() else {
635 return Vec::new();
636 };
637
638 if let Some(enum_values) = obj.get("enum").and_then(|v| v.as_array()) {
639 return enum_values
640 .iter()
641 .filter(|v| !v.is_null())
642 .cloned()
643 .collect();
644 }
645
646 if let Some(const_value) = obj.get("const") {
647 if const_value.is_null() {
648 return Vec::new();
649 }
650 return vec![const_value.clone()];
651 }
652
653 Vec::new()
654}
655
656fn merge_property(properties: &mut serde_json::Map<String, Value>, key: &str, value: &Value) {
657 match properties.get_mut(key) {
658 None => {
659 properties.insert(key.to_string(), value.clone());
660 }
661 Some(existing) => {
662 if existing == value {
663 return;
664 }
665
666 let existing_values = extract_enum_values(existing);
667 let incoming_values = extract_enum_values(value);
668 if incoming_values.is_empty() && existing_values.is_empty() {
669 return;
670 }
671
672 let mut combined = existing_values;
673 for item in incoming_values {
674 if !combined.contains(&item) {
675 combined.push(item);
676 }
677 }
678
679 if combined.is_empty() {
680 return;
681 }
682
683 if let Some(obj) = existing.as_object_mut() {
684 obj.remove("const");
685 obj.insert("enum".to_string(), Value::Array(combined.clone()));
686 if !obj.contains_key("type")
687 && let Some(inferred) = infer_type_from_enum(&combined)
688 {
689 obj.insert("type".to_string(), Value::String(inferred));
690 }
691 }
692 }
693 }
694}
695
696fn merge_union_schemas(root: &Value, variants: &[Value]) -> Value {
697 let mut merged_props = serde_json::Map::new();
698 let mut required_intersection: Option<std::collections::BTreeSet<String>> = None;
699 let mut enum_values: Vec<Value> = Vec::new();
700 let mut type_candidates: Vec<String> = Vec::new();
701
702 for variant in variants {
703 let sanitized = sanitize_for_gemini(root, variant);
704
705 if let Some(schema_type) = sanitized.get("type").and_then(|v| v.as_str()) {
706 type_candidates.push(schema_type.to_string());
707 }
708
709 if let Some(props) = sanitized.get("properties").and_then(|v| v.as_object()) {
710 for (key, value) in props {
711 merge_property(&mut merged_props, key, value);
712 }
713 }
714
715 if let Some(req) = sanitized.get("required").and_then(|v| v.as_array()) {
716 let req_set: std::collections::BTreeSet<String> = req
717 .iter()
718 .filter_map(|item| item.as_str().map(|s| s.to_string()))
719 .collect();
720
721 required_intersection = match required_intersection.take() {
722 None => Some(req_set),
723 Some(existing) => Some(
724 existing
725 .intersection(&req_set)
726 .cloned()
727 .collect::<std::collections::BTreeSet<String>>(),
728 ),
729 };
730 }
731
732 if let Some(values) = sanitized.get("enum").and_then(|v| v.as_array()) {
733 for value in values {
734 if value.is_null() {
735 continue;
736 }
737 if !enum_values.contains(value) {
738 enum_values.push(value.clone());
739 }
740 }
741 }
742 }
743
744 let schema_type = if !merged_props.is_empty() {
745 "object".to_string()
746 } else if let Some(inferred) = infer_type_from_enum(&enum_values) {
747 inferred
748 } else if let Some(first) = type_candidates.first() {
749 first.clone()
750 } else {
751 "string".to_string()
752 };
753
754 let mut merged = serde_json::Map::new();
755 merged.insert("type".to_string(), Value::String(schema_type));
756
757 if !merged_props.is_empty() {
758 merged.insert("properties".to_string(), Value::Object(merged_props));
759 }
760
761 if let Some(required_set) = required_intersection
762 && !required_set.is_empty()
763 {
764 merged.insert(
765 "required".to_string(),
766 Value::Array(
767 required_set
768 .into_iter()
769 .map(Value::String)
770 .collect::<Vec<_>>(),
771 ),
772 );
773 }
774
775 if !enum_values.is_empty() {
776 merged.insert("enum".to_string(), Value::Array(enum_values));
777 }
778
779 Value::Object(merged)
780}
781
782fn sanitize_for_gemini(root: &Value, schema: &Value) -> Value {
783 if let Some(resolved) = resolve_ref(root, schema) {
784 return sanitize_for_gemini(root, resolved);
785 }
786
787 let Some(obj) = schema.as_object() else {
788 return schema.clone();
789 };
790
791 if let Some(union) = obj
792 .get("oneOf")
793 .or_else(|| obj.get("anyOf"))
794 .or_else(|| obj.get("allOf"))
795 .and_then(|v| v.as_array())
796 {
797 return merge_union_schemas(root, union);
798 }
799
800 let mut out = serde_json::Map::new();
801 for (key, value) in obj {
802 match key.as_str() {
803 "$ref"
804 | "$defs"
805 | "oneOf"
806 | "anyOf"
807 | "allOf"
808 | "const"
809 | "additionalProperties"
810 | "default"
811 | "examples"
812 | "title"
813 | "pattern"
814 | "minLength"
815 | "maxLength"
816 | "minimum"
817 | "maximum"
818 | "minItems"
819 | "maxItems"
820 | "uniqueItems"
821 | "deprecated" => {}
822 "type" => {
823 out.insert("type".to_string(), normalize_type(value));
824 }
825 "properties" => {
826 if let Some(props) = value.as_object() {
827 let mut sanitized_props = serde_json::Map::new();
828 for (prop_key, prop_value) in props {
829 sanitized_props
830 .insert(prop_key.clone(), sanitize_for_gemini(root, prop_value));
831 }
832 out.insert("properties".to_string(), Value::Object(sanitized_props));
833 }
834 }
835 "items" => {
836 out.insert("items".to_string(), sanitize_for_gemini(root, value));
837 }
838 "enum" => {
839 let values = value
840 .as_array()
841 .map(|items| {
842 items
843 .iter()
844 .filter(|v| !v.is_null())
845 .cloned()
846 .collect::<Vec<_>>()
847 })
848 .unwrap_or_default();
849 out.insert("enum".to_string(), Value::Array(values));
850 }
851 _ => {
852 out.insert(key.clone(), sanitize_for_gemini(root, value));
853 }
854 }
855 }
856
857 if let Some(const_value) = obj.get("const")
858 && !const_value.is_null()
859 {
860 out.insert("enum".to_string(), Value::Array(vec![const_value.clone()]));
861 if !out.contains_key("type")
862 && let Some(inferred) = infer_type_from_enum(std::slice::from_ref(const_value))
863 {
864 out.insert("type".to_string(), Value::String(inferred));
865 }
866 }
867
868 if out.get("type") == Some(&Value::String("object".to_string()))
869 && !out.contains_key("properties")
870 {
871 out.insert(
872 "properties".to_string(),
873 Value::Object(serde_json::Map::new()),
874 );
875 }
876
877 if !out.contains_key("type") {
878 if out.contains_key("properties") {
879 out.insert("type".to_string(), Value::String("object".to_string()));
880 } else if let Some(enum_values) = out.get("enum").and_then(|v| v.as_array())
881 && let Some(inferred) = infer_type_from_enum(enum_values)
882 {
883 out.insert("type".to_string(), Value::String(inferred));
884 }
885 }
886
887 Value::Object(out)
888}
889
890fn simplify_property_schema(key: &str, tool_name: &str, property_value: &Value) -> Value {
891 if let Some(prop_map_orig) = property_value.as_object() {
892 let mut simplified_prop = prop_map_orig.clone();
893
894 if simplified_prop.remove("additionalProperties").is_some() {
896 debug!(target: "gemini::simplify_property_schema", "Removed 'additionalProperties' from property '{}' in tool '{}'", key, tool_name);
897 }
898
899 if let Some(type_val) = simplified_prop.get_mut("type") {
901 if let Some(type_array) = type_val.as_array() {
902 if let Some(primary_type) = type_array
903 .iter()
904 .find_map(|v| if v.is_null() { None } else { v.as_str() })
905 {
906 *type_val = serde_json::Value::String(primary_type.to_string());
907 } else {
908 warn!(target: "gemini::simplify_property_schema", "Could not determine primary type for property '{}' in tool '{}', defaulting to string.", key, tool_name);
909 *type_val = serde_json::Value::String("string".to_string());
910 }
911 } else if !type_val.is_string() {
912 warn!(target: "gemini::simplify_property_schema", "Unexpected 'type' format for property '{}' in tool '{}': {:?}. Defaulting to string.", key, tool_name, type_val);
913 *type_val = serde_json::Value::String("string".to_string());
914 }
915 }
917
918 if simplified_prop.get("type") == Some(&serde_json::Value::String("integer".to_string()))
920 && let Some(format_val) = simplified_prop.get_mut("format")
921 && format_val.as_str() == Some("uint64")
922 {
923 *format_val = serde_json::Value::String("int64".to_string());
924 }
927
928 if simplified_prop.get("type") == Some(&serde_json::Value::String("string".to_string())) {
930 let should_remove_format = simplified_prop
931 .get("format")
932 .and_then(|f| f.as_str())
933 .is_some_and(|format_str| format_str != "enum" && format_str != "date-time");
934
935 if should_remove_format
936 && let Some(format_val) = simplified_prop.remove("format")
937 && let Some(format_str) = format_val.as_str()
938 {
939 debug!(target: "gemini::simplify_property_schema", "Removed unsupported format '{}' from string property '{}' in tool '{}'", format_str, key, tool_name);
940 }
941
942 if simplified_prop.remove("minLength").is_some() {
944 debug!(target: "gemini::simplify_property_schema", "Removed 'minLength' from string property '{}' in tool '{}'", key, tool_name);
945 }
946 if simplified_prop.remove("maxLength").is_some() {
947 debug!(target: "gemini::simplify_property_schema", "Removed 'maxLength' from string property '{}' in tool '{}'", key, tool_name);
948 }
949 if simplified_prop.remove("pattern").is_some() {
950 debug!(target: "gemini::simplify_property_schema", "Removed 'pattern' from string property '{}' in tool '{}'", key, tool_name);
951 }
952 }
953
954 if simplified_prop.get("type") == Some(&serde_json::Value::String("array".to_string()))
956 && let Some(items_val) = simplified_prop.get_mut("items")
957 {
958 *items_val = simplify_property_schema(&format!("{key}.items"), tool_name, items_val);
959 }
960
961 if simplified_prop.get("type") == Some(&serde_json::Value::String("object".to_string()))
963 && let Some(Value::Object(props)) = simplified_prop.get_mut("properties")
964 {
965 let simplified_nested_props: serde_json::Map<String, Value> = props
966 .iter()
967 .map(|(nested_key, nested_value)| {
968 (
969 nested_key.clone(),
970 simplify_property_schema(
971 &format!("{key}.{nested_key}"),
972 tool_name,
973 nested_value,
974 ),
975 )
976 })
977 .collect();
978 *props = simplified_nested_props;
979 }
980
981 serde_json::Value::Object(simplified_prop)
982 } else {
983 warn!(target: "gemini::simplify_property_schema", "Property value for '{}' in tool '{}' is not an object: {:?}. Using original value.", key, tool_name, property_value);
984 property_value.clone() }
986}
987
988fn convert_tools(tools: Vec<ToolSchema>) -> Vec<GeminiTool> {
989 let function_declarations = tools
990 .into_iter()
991 .map(|tool| {
992 let root_schema = tool.input_schema.as_value();
993 let summary = tool.input_schema.summary();
994 let schema_type = if summary.schema_type.is_empty() {
995 "object".to_string()
996 } else {
997 summary.schema_type
998 };
999
1000 let simplified_properties = summary
1002 .properties
1003 .iter()
1004 .map(|(key, value)| {
1005 let sanitized = sanitize_for_gemini(root_schema, value);
1006 (
1007 key.clone(),
1008 simplify_property_schema(key, &tool.name, &sanitized),
1009 )
1010 })
1011 .collect();
1012
1013 let parameters = GeminiParameterSchema {
1015 schema_type, properties: simplified_properties, required: summary.required, };
1019
1020 GeminiFunctionDeclaration {
1021 name: tool.name,
1022 description: tool.description,
1023 parameters,
1024 }
1025 })
1026 .collect();
1027
1028 vec![GeminiTool {
1029 function_declarations,
1030 }]
1031}
1032
1033fn convert_response(response: GeminiResponse) -> Result<CompletionResponse, ApiError> {
1034 if let Some(feedback) = &response.prompt_feedback
1036 && let Some(reason) = &feedback.block_reason
1037 {
1038 let details = format!(
1039 "Prompt blocked due to {:?}. Safety ratings: {:?}",
1040 reason, feedback.safety_ratings
1041 );
1042 warn!(target: "gemini::convert_response", "{}", details);
1043 return Err(ApiError::RequestBlocked {
1045 provider: "google".to_string(), details,
1047 });
1048 }
1049
1050 let candidates = if let Some(cands) = response.candidates {
1052 if cands.is_empty() {
1053 warn!(target: "gemini::convert_response", "No candidates received, and prompt was not blocked.");
1056 return Err(ApiError::NoChoices {
1058 provider: "google".to_string(),
1059 });
1060 }
1061 cands } else {
1063 warn!(target: "gemini::convert_response", "No candidates field in Gemini response.");
1064 return Err(ApiError::NoChoices {
1066 provider: "google".to_string(),
1067 });
1068 };
1069
1070 let candidate = &candidates[0];
1073
1074 if let Some(reason) = &candidate.finish_reason {
1076 match reason {
1077 GeminiFinishReason::Stop => { }
1078 GeminiFinishReason::MaxTokens => {
1079 warn!(target: "gemini::convert_response", "Response stopped due to MaxTokens limit.");
1080 }
1081 GeminiFinishReason::Safety => {
1082 warn!(target: "gemini::convert_response", "Response stopped due to safety settings. Ratings: {:?}", candidate.safety_ratings);
1083 }
1085 GeminiFinishReason::Recitation => {
1086 warn!(target: "gemini::convert_response", "Response stopped due to potential recitation. Citations: {:?}", candidate.citation_metadata);
1087 }
1088 GeminiFinishReason::MalformedFunctionCall => {
1089 warn!(target: "gemini::convert_response", "Response stopped due to malformed function call.");
1090 }
1091 _ => {
1092 info!(target: "gemini::convert_response", "Response finished with reason: {:?}", reason);
1093 }
1094 }
1095 }
1096
1097 if let Some(usage) = &response.usage_metadata {
1099 debug!(target: "gemini::convert_response", "Usage - Prompt Tokens: {:?}, Candidates Tokens: {:?}, Total Tokens: {:?}",
1100 usage.prompt, usage.candidates, usage.total);
1101 }
1102
1103 let content: Vec<AssistantContent> = candidate
1104 .content .parts .iter()
1107 .filter_map(|part| { if part.thought {
1110 debug!(target: "gemini::convert_response", "Received thought part: {:?}", part);
1111 if let GeminiResponsePartData::Text { text } = &part.data {
1113 Some(AssistantContent::Thought {
1114 thought: ThoughtContent::Simple {
1115 text: text.clone(),
1116 },
1117 })
1118 } else {
1119 warn!(target: "gemini::convert_response", "Thought part contains non-text data: {:?}", part.data);
1120 None
1121 }
1122 } else {
1123 match &part.data {
1125 GeminiResponsePartData::Text { text } => Some(AssistantContent::Text {
1126 text: text.clone(),
1127 }),
1128 GeminiResponsePartData::InlineData { inline_data } => {
1129 warn!(target: "gemini::convert_response", "Received InlineData part (MIME type: {}). Converting to placeholder text.", inline_data.mime_type);
1130 Some(AssistantContent::Text { text: format!("[Inline Data: {}]", inline_data.mime_type) })
1131 }
1132 GeminiResponsePartData::FunctionCall { function_call } => {
1133 Some(AssistantContent::ToolCall {
1134 tool_call: steer_tools::ToolCall {
1135 id: uuid::Uuid::new_v4().to_string(), name: function_call.name.clone(),
1137 parameters: function_call.args.clone(),
1138 },
1139 thought_signature: part
1140 .thought_signature
1141 .clone()
1142 .map(ThoughtSignature::new),
1143 })
1144 }
1145 GeminiResponsePartData::FileData { file_data } => {
1146 warn!(target: "gemini::convert_response", "Received FileData part (URI: {}). Converting to placeholder text.", file_data.file_uri);
1147 Some(AssistantContent::Text { text: format!("[File Data: {}]", file_data.file_uri) })
1148 }
1149 GeminiResponsePartData::ExecutableCode { executable_code } => {
1150 info!(target: "gemini::convert_response", "Received ExecutableCode part ({}). Converting to text.",
1151 executable_code.language);
1152 Some(AssistantContent::Text {
1153 text: format!(
1154 "```{}
1155{}
1156```",
1157 executable_code.language.to_lowercase(),
1158 executable_code.code
1159 ),
1160 })
1161 }
1162 }
1163 }
1164 })
1165 .collect();
1166
1167 Ok(CompletionResponse { content })
1168}
1169
1170#[async_trait]
1171impl Provider for GeminiClient {
1172 fn name(&self) -> &'static str {
1173 "google"
1174 }
1175
1176 async fn complete(
1177 &self,
1178 model_id: &ModelId,
1179 messages: Vec<AppMessage>,
1180 system: Option<SystemContext>,
1181 tools: Option<Vec<ToolSchema>>,
1182 _call_options: Option<ModelParameters>,
1183 token: CancellationToken,
1184 ) -> Result<CompletionResponse, ApiError> {
1185 let model_name = &model_id.id; let url = format!(
1187 "{}/models/{}:generateContent?key={}",
1188 GEMINI_API_BASE, model_name, self.api_key
1189 );
1190
1191 let gemini_contents = convert_messages(messages)?;
1192
1193 let system_instruction = system
1194 .and_then(|context| context.render())
1195 .map(|instructions| GeminiSystemInstruction {
1196 parts: vec![GeminiRequestPart::Text { text: instructions }],
1197 });
1198
1199 let gemini_tools = tools.map(convert_tools);
1200
1201 let (temperature, top_p, max_output_tokens) = {
1203 let opts = _call_options.as_ref();
1204 (
1205 opts.and_then(|o| o.temperature).or(Some(1.0)),
1206 opts.and_then(|o| o.top_p).or(Some(0.95)),
1207 opts.and_then(|o| o.max_tokens)
1208 .map(|v| v as i32)
1209 .or(Some(65536)),
1210 )
1211 };
1212 let thinking_config = _call_options
1213 .as_ref()
1214 .and_then(|o| o.thinking_config)
1215 .and_then(|tc| {
1216 if tc.enabled {
1217 Some(GeminiThinkingConfig {
1218 include_thoughts: tc.include_thoughts,
1219 thinking_budget: tc.budget_tokens.map(|v| v as i32),
1220 })
1221 } else {
1222 None
1223 }
1224 });
1225
1226 let request = GeminiRequest {
1227 contents: gemini_contents,
1228 system_instruction,
1229 tools: gemini_tools,
1230 generation_config: Some(GeminiGenerationConfig {
1231 max_output_tokens,
1232 temperature,
1233 top_p,
1234 thinking_config,
1235 ..Default::default()
1236 }),
1237 };
1238
1239 let response = tokio::select! {
1240 biased;
1241 () = token.cancelled() => {
1242 debug!(target: "gemini::complete", "Cancellation token triggered before sending request.");
1243 return Err(ApiError::Cancelled{ provider: self.name().to_string()});
1244 }
1245 res = self.client.post(&url).json(&request).send() => {
1246 res.map_err(ApiError::Network)?
1247 }
1248 };
1249 let status = response.status();
1250
1251 if status != StatusCode::OK {
1252 let error_text = response.text().await.map_err(ApiError::Network)?;
1253 error!(target: "Gemini API Error Response", "Status: {}, Body: {}", status, error_text);
1254 return Err(match status.as_u16() {
1255 401 | 403 => ApiError::AuthenticationFailed {
1256 provider: self.name().to_string(),
1257 details: error_text,
1258 },
1259 429 => ApiError::RateLimited {
1260 provider: self.name().to_string(),
1261 details: error_text,
1262 },
1263 400 | 404 => {
1264 error!(target: "Gemini API Error Response", "Status: {}, Body: {}, Request: {}", status, error_text, serde_json::to_string_pretty(&request).unwrap_or_else(|_| "Failed to serialize request".to_string()));
1265 ApiError::InvalidRequest {
1266 provider: self.name().to_string(),
1267 details: error_text,
1268 }
1269 } 500..=599 => ApiError::ServerError {
1271 provider: self.name().to_string(),
1272 status_code: status.as_u16(),
1273 details: error_text,
1274 },
1275 _ => ApiError::Unknown {
1276 provider: self.name().to_string(),
1277 details: error_text,
1278 },
1279 });
1280 }
1281
1282 let response_text = response.text().await.map_err(ApiError::Network)?;
1283
1284 match serde_json::from_str::<GeminiResponse>(&response_text) {
1285 Ok(gemini_response) => {
1286 convert_response(gemini_response).map_err(|e| ApiError::ResponseParsingError {
1287 provider: self.name().to_string(),
1288 details: e.to_string(),
1289 })
1290 }
1291 Err(e) => {
1292 error!(target: "Gemini API JSON Parsing Error", "Failed to parse JSON: {}. Response body:\n{}", e, response_text);
1293 Err(ApiError::ResponseParsingError {
1294 provider: self.name().to_string(),
1295 details: format!("Status: {status}, Error: {e}, Body: {response_text}"),
1296 })
1297 }
1298 }
1299 }
1300
1301 async fn stream_complete(
1302 &self,
1303 model_id: &ModelId,
1304 messages: Vec<AppMessage>,
1305 system: Option<SystemContext>,
1306 tools: Option<Vec<ToolSchema>>,
1307 _call_options: Option<ModelParameters>,
1308 token: CancellationToken,
1309 ) -> Result<CompletionStream, ApiError> {
1310 let model_name = &model_id.id;
1311 let url = format!(
1312 "{}/models/{}:streamGenerateContent?alt=sse&key={}",
1313 GEMINI_API_BASE, model_name, self.api_key
1314 );
1315
1316 let gemini_contents = convert_messages(messages)?;
1317
1318 let system_instruction = system
1319 .and_then(|context| context.render())
1320 .map(|instructions| GeminiSystemInstruction {
1321 parts: vec![GeminiRequestPart::Text { text: instructions }],
1322 });
1323
1324 let gemini_tools = tools.map(convert_tools);
1325
1326 let (temperature, top_p, max_output_tokens) = {
1327 let opts = _call_options.as_ref();
1328 (
1329 opts.and_then(|o| o.temperature).or(Some(1.0)),
1330 opts.and_then(|o| o.top_p).or(Some(0.95)),
1331 opts.and_then(|o| o.max_tokens)
1332 .map(|v| v as i32)
1333 .or(Some(65536)),
1334 )
1335 };
1336 let thinking_config = _call_options
1337 .as_ref()
1338 .and_then(|o| o.thinking_config)
1339 .and_then(|tc| {
1340 if tc.enabled {
1341 Some(GeminiThinkingConfig {
1342 include_thoughts: tc.include_thoughts,
1343 thinking_budget: tc.budget_tokens.map(|v| v as i32),
1344 })
1345 } else {
1346 None
1347 }
1348 });
1349
1350 let request = GeminiRequest {
1351 contents: gemini_contents,
1352 system_instruction,
1353 tools: gemini_tools,
1354 generation_config: Some(GeminiGenerationConfig {
1355 max_output_tokens,
1356 temperature,
1357 top_p,
1358 thinking_config,
1359 ..Default::default()
1360 }),
1361 };
1362
1363 let response = tokio::select! {
1364 biased;
1365 () = token.cancelled() => {
1366 return Err(ApiError::Cancelled{ provider: self.name().to_string()});
1367 }
1368 res = self.client.post(&url).json(&request).send() => {
1369 res.map_err(ApiError::Network)?
1370 }
1371 };
1372
1373 let status = response.status();
1374 if status != StatusCode::OK {
1375 let error_text = response.text().await.map_err(ApiError::Network)?;
1376 error!(target: "gemini::stream", "API error - Status: {}, Body: {}", status, error_text);
1377 return Err(match status.as_u16() {
1378 401 | 403 => ApiError::AuthenticationFailed {
1379 provider: self.name().to_string(),
1380 details: error_text,
1381 },
1382 429 => ApiError::RateLimited {
1383 provider: self.name().to_string(),
1384 details: error_text,
1385 },
1386 400 | 404 => ApiError::InvalidRequest {
1387 provider: self.name().to_string(),
1388 details: error_text,
1389 },
1390 500..=599 => ApiError::ServerError {
1391 provider: self.name().to_string(),
1392 status_code: status.as_u16(),
1393 details: error_text,
1394 },
1395 _ => ApiError::Unknown {
1396 provider: self.name().to_string(),
1397 details: error_text,
1398 },
1399 });
1400 }
1401
1402 let byte_stream = response.bytes_stream();
1403 let sse_stream = parse_sse_stream(byte_stream);
1404
1405 Ok(Box::pin(Self::convert_gemini_stream(sse_stream, token)))
1406 }
1407}
1408
1409impl GeminiClient {
1410 fn convert_gemini_stream(
1411 mut sse_stream: impl futures::Stream<Item = Result<crate::api::sse::SseEvent, SseParseError>>
1412 + Unpin
1413 + Send
1414 + 'static,
1415 token: CancellationToken,
1416 ) -> impl futures::Stream<Item = StreamChunk> + Send + 'static {
1417 async_stream::stream! {
1418 let mut content: Vec<AssistantContent> = Vec::new();
1419 loop {
1420 if token.is_cancelled() {
1421 yield StreamChunk::Error(StreamError::Cancelled);
1422 break;
1423 }
1424
1425 let event_result = tokio::select! {
1426 biased;
1427 () = token.cancelled() => {
1428 yield StreamChunk::Error(StreamError::Cancelled);
1429 break;
1430 }
1431 event = sse_stream.next() => event
1432 };
1433
1434 let Some(event_result) = event_result else {
1435 let content = std::mem::take(&mut content);
1436 yield StreamChunk::MessageComplete(CompletionResponse { content });
1437 break;
1438 };
1439
1440 let event = match event_result {
1441 Ok(e) => e,
1442 Err(e) => {
1443 yield StreamChunk::Error(StreamError::SseParse(e));
1444 break;
1445 }
1446 };
1447
1448 let chunk: GeminiResponse = match serde_json::from_str(&event.data) {
1449 Ok(c) => c,
1450 Err(e) => {
1451 debug!(target: "gemini::stream", "Failed to parse chunk: {} data: {}", e, event.data);
1452 continue;
1453 }
1454 };
1455
1456 if let Some(candidates) = chunk.candidates {
1457 for candidate in candidates {
1458 for part in candidate.content.parts {
1459 let GeminiResponsePart {
1460 thought,
1461 thought_signature,
1462 data,
1463 } = part;
1464 let thought_signature =
1465 thought_signature.map(ThoughtSignature::new);
1466
1467 if thought {
1468 if let GeminiResponsePartData::Text { text } = data {
1469 match content.last_mut() {
1470 Some(AssistantContent::Thought {
1471 thought: ThoughtContent::Simple { text: buf },
1472 }) => buf.push_str(&text),
1473 _ => {
1474 content.push(AssistantContent::Thought {
1475 thought: ThoughtContent::Simple { text: text.clone() },
1476 });
1477 }
1478 }
1479 yield StreamChunk::ThinkingDelta(text);
1480 }
1481 } else {
1482 match data {
1483 GeminiResponsePartData::Text { text } => {
1484 match content.last_mut() {
1485 Some(AssistantContent::Text { text: buf }) => buf.push_str(&text),
1486 _ => {
1487 content.push(AssistantContent::Text { text: text.clone() });
1488 }
1489 }
1490 yield StreamChunk::TextDelta(text);
1491 }
1492 GeminiResponsePartData::FunctionCall { function_call } => {
1493 let id = uuid::Uuid::new_v4().to_string();
1494 content.push(AssistantContent::ToolCall {
1495 tool_call: steer_tools::ToolCall {
1496 id: id.clone(),
1497 name: function_call.name.clone(),
1498 parameters: function_call.args.clone(),
1499 },
1500 thought_signature,
1501 });
1502 yield StreamChunk::ToolUseStart {
1503 id: id.clone(),
1504 name: function_call.name,
1505 };
1506 yield StreamChunk::ToolUseInputDelta {
1507 id,
1508 delta: function_call.args.to_string(),
1509 };
1510 }
1511 _ => {}
1512 }
1513 }
1514 }
1515 }
1516 }
1517 }
1518 }
1519 }
1520}
1521
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::sse::SseEvent;
    use crate::app::conversation::{ImageContent, ImageSource, Message, MessageData, UserContent};
    use futures::stream;
    use serde_json::json;
    use std::pin::pin;

    /// Builds a successful SSE event carrying `data`, with no event type or id.
    fn sse_event(data: &str) -> Result<SseEvent, SseParseError> {
        Ok(SseEvent {
            event_type: None,
            data: data.to_string(),
            id: None,
        })
    }

    #[test]
    fn test_simplify_property_schema_removes_additional_properties() {
        let property_value = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            },
            "additionalProperties": false
        });

        let expected = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });

        let result = simplify_property_schema("testProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_removes_unsupported_string_formats() {
        let property_value = json!({
            "type": "string",
            "format": "uri",
            "minLength": 1,
            "maxLength": 100,
            "pattern": "^https://"
        });

        let expected = json!({
            "type": "string"
        });

        let result = simplify_property_schema("urlProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_keeps_supported_string_formats() {
        let property_value = json!({
            "type": "string",
            "format": "date-time"
        });

        let expected = json!({
            "type": "string",
            "format": "date-time"
        });

        let result = simplify_property_schema("dateProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_handles_array_types() {
        let property_value = json!({
            "type": ["string", "null"],
            "format": "email"
        });

        let expected = json!({
            "type": "string"
        });

        let result = simplify_property_schema("emailProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_recursively_handles_array_items() {
        let property_value = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "url": {
                        "type": "string",
                        "format": "uri"
                    }
                },
                "additionalProperties": false
            }
        });

        let expected = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "url": {
                        "type": "string"
                    }
                }
            }
        });

        let result = simplify_property_schema("linksProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_recursively_handles_nested_objects() {
        let property_value = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "type": "object",
                    "properties": {
                        "field": {
                            "type": "string",
                            "format": "hostname"
                        }
                    },
                    "additionalProperties": true
                }
            },
            "additionalProperties": false
        });

        let expected = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "type": "object",
                    "properties": {
                        "field": {
                            "type": "string"
                        }
                    }
                }
            }
        });

        let result = simplify_property_schema("complexProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_fixes_uint64_format() {
        let property_value = json!({
            "type": "integer",
            "format": "uint64"
        });

        let expected = json!({
            "type": "integer",
            "format": "int64"
        });

        let result = simplify_property_schema("idProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_convert_tools_integration() {
        use steer_tools::{InputSchema, ToolSchema};

        let tool = ToolSchema {
            name: "create_issue".to_string(),
            display_name: "Create Issue".to_string(),
            description: "Create an issue".to_string(),
            input_schema: InputSchema::object(
                {
                    let mut props = serde_json::Map::new();
                    props.insert(
                        "title".to_string(),
                        json!({
                            "type": "string",
                            "minLength": 1
                        }),
                    );
                    props.insert(
                        "links".to_string(),
                        json!({
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "url": {
                                        "type": "string",
                                        "format": "uri"
                                    }
                                },
                                "additionalProperties": false
                            }
                        }),
                    );
                    props
                },
                vec!["title".to_string()],
            ),
        };

        let expected_tools = vec![GeminiTool {
            function_declarations: vec![GeminiFunctionDeclaration {
                name: "create_issue".to_string(),
                description: "Create an issue".to_string(),
                parameters: GeminiParameterSchema {
                    schema_type: "object".to_string(),
                    properties: {
                        let mut props = serde_json::Map::new();
                        props.insert(
                            "title".to_string(),
                            json!({
                                "type": "string"
                            }),
                        );
                        props.insert(
                            "links".to_string(),
                            json!({
                                "type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "url": {
                                            "type": "string"
                                        }
                                    }
                                }
                            }),
                        );
                        props
                    },
                    required: vec!["title".to_string()],
                },
            }],
        }];

        let result = convert_tools(vec![tool]);
        assert_eq!(result, expected_tools);
    }

    #[test]
    fn test_convert_messages_includes_data_url_image_as_inline_data() {
        let messages = vec![Message {
            data: MessageData::User {
                content: vec![UserContent::Image {
                    image: ImageContent {
                        mime_type: "image/png".to_string(),
                        source: ImageSource::DataUrl {
                            data_url: "data:image/png;base64,Zm9v".to_string(),
                        },
                        width: None,
                        height: None,
                        bytes: None,
                        sha256: None,
                    },
                }],
            },
            timestamp: 1,
            id: "msg-image".to_string(),
            parent_message_id: None,
        }];

        let converted = convert_messages(messages).expect("convert messages");
        assert_eq!(converted.len(), 1);
        let first = &converted[0];
        assert_eq!(first.role, "user");
        assert_eq!(first.parts.len(), 1);

        let json = serde_json::to_value(&first.parts[0]).expect("serialize part");
        assert_eq!(json["inlineData"]["mimeType"], "image/png");
        assert_eq!(json["inlineData"]["data"], "Zm9v");
    }

    #[test]
    fn test_convert_messages_rejects_session_file_image_source() {
        let messages = vec![Message {
            data: MessageData::User {
                content: vec![UserContent::Image {
                    image: ImageContent {
                        mime_type: "image/png".to_string(),
                        source: ImageSource::SessionFile {
                            relative_path: "session-1/image.png".to_string(),
                        },
                        width: None,
                        height: None,
                        bytes: None,
                        sha256: None,
                    },
                }],
            },
            timestamp: 1,
            id: "msg-image".to_string(),
            parent_message_id: None,
        }];

        let err = convert_messages(messages).expect_err("expected unsupported feature");
        match err {
            ApiError::UnsupportedFeature {
                provider,
                feature,
                details,
            } => {
                assert_eq!(provider, "google");
                assert_eq!(feature, "image input source");
                assert!(details.contains("session file"));
            }
            other => panic!("Expected UnsupportedFeature, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_text_deltas() {
        let events = vec![
            sse_event(r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#),
            sse_event(r#"{"candidates":[{"content":{"role":"model","parts":[{"text":" world"}]}}]}"#),
        ];

        let token = CancellationToken::new();
        let mut stream = pin!(GeminiClient::convert_gemini_stream(stream::iter(events), token));

        let first_delta = stream.next().await.unwrap();
        assert!(matches!(first_delta, StreamChunk::TextDelta(ref t) if t == "Hello"));

        let second_delta = stream.next().await.unwrap();
        assert!(matches!(second_delta, StreamChunk::TextDelta(ref t) if t == " world"));

        let StreamChunk::MessageComplete(response) = stream.next().await.unwrap() else {
            panic!("Expected MessageComplete");
        };
        // Consecutive text deltas are merged into a single content item.
        assert_eq!(response.content.len(), 1);
        assert!(matches!(
            &response.content[0],
            AssistantContent::Text { text } if text == "Hello world"
        ));

        assert!(
            stream.next().await.is_none(),
            "stream should end after MessageComplete"
        );
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_with_thinking() {
        let events = vec![
            sse_event(
                r#"{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"Let me think..."}]}}]}"#,
            ),
            sse_event(
                r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"The answer"}]}}]}"#,
            ),
        ];

        let token = CancellationToken::new();
        let mut stream = pin!(GeminiClient::convert_gemini_stream(stream::iter(events), token));

        let thinking_delta = stream.next().await.unwrap();
        assert!(
            matches!(thinking_delta, StreamChunk::ThinkingDelta(ref t) if t == "Let me think...")
        );

        let text_delta = stream.next().await.unwrap();
        assert!(matches!(text_delta, StreamChunk::TextDelta(ref t) if t == "The answer"));

        let StreamChunk::MessageComplete(response) = stream.next().await.unwrap() else {
            panic!("Expected MessageComplete");
        };
        assert_eq!(response.content.len(), 2);
        assert!(matches!(
            &response.content[0],
            AssistantContent::Thought {
                thought: ThoughtContent::Simple { text }
            } if text == "Let me think..."
        ));
        assert!(matches!(
            &response.content[1],
            AssistantContent::Text { text } if text == "The answer"
        ));
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_with_function_call() {
        let events = vec![sse_event(
            r#"{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"get_weather","args":{"city":"NYC"}},"thoughtSignature":"sig_123"}]}}]}"#,
        )];

        let token = CancellationToken::new();
        let mut stream = pin!(GeminiClient::convert_gemini_stream(stream::iter(events), token));

        let StreamChunk::ToolUseStart { id: start_id, name } = stream.next().await.unwrap() else {
            panic!("Expected ToolUseStart");
        };
        assert_eq!(name, "get_weather");

        let StreamChunk::ToolUseInputDelta {
            id: delta_id,
            delta,
        } = stream.next().await.unwrap()
        else {
            panic!("Expected ToolUseInputDelta");
        };
        assert_eq!(
            delta_id, start_id,
            "input delta must reference the started tool call"
        );
        let parsed: serde_json::Value =
            serde_json::from_str(&delta).expect("tool input delta should be valid JSON");
        assert_eq!(parsed, json!({"city": "NYC"}));

        let StreamChunk::MessageComplete(response) = stream.next().await.unwrap() else {
            panic!("Expected MessageComplete");
        };
        assert_eq!(response.content.len(), 1);
        let AssistantContent::ToolCall {
            tool_call,
            thought_signature,
        } = &response.content[0]
        else {
            panic!("Expected ToolCall");
        };
        assert_eq!(tool_call.id, start_id);
        assert_eq!(tool_call.name, "get_weather");
        assert_eq!(tool_call.parameters, json!({"city": "NYC"}));
        assert_eq!(
            thought_signature.as_ref().map(|sig| sig.as_str()),
            Some("sig_123")
        );
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_skips_unparseable_chunks() {
        let events = vec![
            sse_event("not json"),
            sse_event(r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#),
        ];

        let token = CancellationToken::new();
        let mut stream = pin!(GeminiClient::convert_gemini_stream(stream::iter(events), token));

        let delta = stream.next().await.unwrap();
        assert!(matches!(delta, StreamChunk::TextDelta(ref t) if t == "Hello"));

        let StreamChunk::MessageComplete(response) = stream.next().await.unwrap() else {
            panic!("Expected MessageComplete");
        };
        assert_eq!(response.content.len(), 1);
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_empty_input_completes_without_content() {
        let events: Vec<Result<SseEvent, SseParseError>> = Vec::new();

        let token = CancellationToken::new();
        let mut stream = pin!(GeminiClient::convert_gemini_stream(stream::iter(events), token));

        let StreamChunk::MessageComplete(response) = stream.next().await.unwrap() else {
            panic!("Expected MessageComplete");
        };
        assert!(response.content.is_empty());
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_cancellation() {
        let events = vec![sse_event(
            r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#,
        )];

        let token = CancellationToken::new();
        token.cancel();

        let mut stream = pin!(GeminiClient::convert_gemini_stream(stream::iter(events), token));

        let cancelled = stream.next().await.unwrap();
        assert!(matches!(
            cancelled,
            StreamChunk::Error(StreamError::Cancelled)
        ));
        assert!(
            stream.next().await.is_none(),
            "stream should end after cancellation"
        );
    }

    #[tokio::test]
    #[ignore = "Requires GOOGLE_API_KEY environment variable"]
    async fn test_stream_complete_real_api() {
        dotenvy::dotenv().ok();
        let api_key = std::env::var("GOOGLE_API_KEY").expect("GOOGLE_API_KEY must be set");
        let client = GeminiClient::new(api_key);

        let message = Message {
            data: MessageData::User {
                content: vec![UserContent::Text {
                    text: "Say exactly: Hello".to_string(),
                }],
            },
            timestamp: chrono::Utc::now().timestamp_millis() as u64,
            id: "test-msg".to_string(),
            parent_message_id: None,
        };

        // Uses the GA model; the earlier dated preview id has been retired.
        let model_id = ModelId::new(crate::config::provider::google(), "gemini-2.5-flash");
        let token = CancellationToken::new();

        let mut stream = client
            .stream_complete(&model_id, vec![message], None, None, None, token)
            .await
            .expect("stream_complete should succeed");

        let mut got_text_delta = false;
        let mut got_message_complete = false;
        let mut accumulated_text = String::new();

        while let Some(chunk) = stream.next().await {
            match chunk {
                StreamChunk::TextDelta(text) => {
                    got_text_delta = true;
                    accumulated_text.push_str(&text);
                }
                StreamChunk::MessageComplete(response) => {
                    got_message_complete = true;
                    assert!(!response.content.is_empty());
                }
                StreamChunk::Error(e) => panic!("Unexpected error: {e:?}"),
                _ => {}
            }
        }

        assert!(got_text_delta, "Should receive at least one TextDelta");
        assert!(
            got_message_complete,
            "Should receive MessageComplete at the end"
        );
        assert!(
            accumulated_text.to_lowercase().contains("hello"),
            "Response should contain 'hello', got: {accumulated_text}"
        );
    }
}