1use async_trait::async_trait;
2use futures::StreamExt;
3use reqwest::{Client as HttpClient, StatusCode};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use tokio_util::sync::CancellationToken;
7use tracing::{debug, error, info, warn};
8
9use crate::api::error::{ApiError, SseParseError, StreamError};
10use crate::api::provider::{
11 CompletionResponse, CompletionStream, Provider, StreamChunk, TokenUsage,
12};
13use crate::api::sse::parse_sse_stream;
14use crate::api::util::map_http_status_to_api_error;
15use crate::app::SystemContext;
16use crate::app::conversation::{
17 AssistantContent, ImageSource, Message as AppMessage, ThoughtContent, ThoughtSignature,
18 ToolResult, UserContent,
19};
20use crate::config::model::{ModelId, ModelParameters};
21use steer_tools::ToolSchema;
22
/// Base URL of the Gemini (Generative Language) REST API, pinned to the `v1beta` version.
const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
24
/// Binary payload carried inline in a request or response part (`inlineData`).
#[derive(Debug, Deserialize, Serialize, Clone)]
struct GeminiBlob {
    /// Media type of `data`, e.g. `image/png`.
    #[serde(rename = "mimeType")]
    mime_type: String,
    /// Encoded payload. Requests built in this module copy it verbatim from the body of a
    /// base64 data URL.
    data: String,
}

/// Reference to a file by URI instead of inline bytes (`fileData`).
#[derive(Debug, Deserialize, Serialize, Clone)]
struct GeminiFileData {
    /// Media type of the referenced file.
    #[serde(rename = "mimeType")]
    mime_type: String,
    /// URI of the file; image URLs from the conversation are forwarded here unchanged.
    #[serde(rename = "fileUri")]
    file_uri: String,
}

/// Result of server-side code execution. Declared but not currently part of
/// `GeminiResponsePartData`, hence the `dead_code` expectation.
#[expect(dead_code)]
#[derive(Debug, Deserialize, Serialize, Clone)]
struct GeminiCodeExecutionResult {
    outcome: String,
}
46
/// Gemini provider client: an API key plus the HTTP client used to reach the API.
pub struct GeminiClient {
    /// Key used to authenticate requests to the Gemini API.
    api_key: String,
    /// `reqwest` client used for requests.
    client: HttpClient,
}
51
52impl GeminiClient {
53 pub fn new(api_key: impl Into<String>) -> Self {
54 Self {
55 api_key: api_key.into(),
56 client: HttpClient::new(),
57 }
58 }
59}
60
/// Top-level JSON body sent to the Gemini API. Optional sections are left out of the JSON
/// when `None`.
#[derive(Debug, Serialize)]
struct GeminiRequest {
    /// Conversation turns, as produced by `convert_messages`.
    contents: Vec<GeminiContent>,
    /// System prompt, sent separately from `contents`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "systemInstruction")]
    system_instruction: Option<GeminiSystemInstruction>,
    /// Function declarations offered to the model (see `convert_tools`).
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GeminiTool>>,
    /// Generation settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "generationConfig")]
    generation_config: Option<GeminiGenerationConfig>,
}

/// `generationConfig` section of a request. Every field is optional and omitted from the
/// JSON when `None`, so `GeminiGenerationConfig::default()` serializes to `{}`.
/// Field meanings follow the Gemini API's `GenerationConfig` reference.
#[derive(Debug, Serialize, Default, Clone)]
struct GeminiGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "stopSequences")]
    stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "responseMimeType")]
    response_mime_type: Option<GeminiMimeType>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "candidateCount")]
    candidate_count: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "maxOutputTokens")]
    max_output_tokens: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "topP")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "topK")]
    top_k: Option<i32>,
    /// Thinking controls; see [`GeminiThinkingConfig`].
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "thinkingConfig")]
    thinking_config: Option<GeminiThinkingConfig>,
}

/// `thinkingConfig` section of the generation config; fields are omitted when `None`.
#[derive(Debug, Serialize, Default, Clone)]
struct GeminiThinkingConfig {
    /// Whether the response should include thought parts (parts flagged `thought: true`).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "includeThoughts")]
    include_thoughts: Option<bool>,
    /// Thinking budget, as interpreted by the Gemini API.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "thinkingBudget")]
    thinking_budget: Option<i32>,
}
110
111#[expect(dead_code)]
112#[derive(Debug, Serialize, Clone)]
113#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
114enum GeminiMimeType {
115 MimeTypeUnspecified,
116 TextPlain,
117 ApplicationJson,
118}
119
/// `systemInstruction` of a request: the system prompt as a list of parts.
#[derive(Debug, Serialize)]
struct GeminiSystemInstruction {
    parts: Vec<GeminiRequestPart>,
}

/// One conversation turn sent to the API.
#[derive(Debug, Serialize)]
struct GeminiContent {
    /// Producer of the turn; this module uses `"user"`, `"model"` and `"function"`.
    role: String,
    /// Ordered parts of the turn.
    parts: Vec<GeminiRequestPart>,
}
130
/// One part of a request turn.
///
/// The enum is `untagged`, so each variant serializes as just its fields. The field-level
/// `rename`s produce the wire keys (`text`, `inlineData`, `fileData`, `functionCall`,
/// `functionResponse`); the variant-level `rename`s have no effect under `untagged`.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum GeminiRequestPart {
    /// Plain text.
    Text {
        text: String,
    },
    /// Inline binary data (images from data URLs in this module).
    #[serde(rename = "inlineData")]
    InlineData {
        #[serde(rename = "inlineData")]
        inline_data: GeminiBlob,
    },
    /// File referenced by URI (image URLs in this module).
    #[serde(rename = "fileData")]
    FileData {
        #[serde(rename = "fileData")]
        file_data: GeminiFileData,
    },
    /// A tool call the model made earlier, replayed as conversation history.
    #[serde(rename = "functionCall")]
    FunctionCall {
        #[serde(rename = "functionCall")]
        function_call: GeminiFunctionCall,
        /// Thought signature received alongside the model's original call; omitted when
        /// `None`.
        #[serde(rename = "thoughtSignature", skip_serializing_if = "Option::is_none")]
        thought_signature: Option<String>,
    },
    /// Result of a tool call, sent back to the model.
    #[serde(rename = "functionResponse")]
    FunctionResponse {
        #[serde(rename = "functionResponse")]
        function_response: GeminiFunctionResponse,
    },
}
161
/// Payload of one response part.
///
/// The enum is `untagged`: serde tries the variants in declaration order and keeps the
/// first one that deserializes successfully, so a part matching none of them makes
/// deserialization fail. The variant-level `rename`s have no effect under `untagged`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum GeminiResponsePartData {
    /// Text: answer text, or reasoning when the enclosing part has `thought: true`.
    Text {
        text: String,
    },
    /// Inline binary data.
    #[serde(rename = "inlineData")]
    InlineData {
        #[serde(rename = "inlineData")]
        inline_data: GeminiBlob,
    },
    /// A tool call requested by the model.
    #[serde(rename = "functionCall")]
    FunctionCall {
        #[serde(rename = "functionCall")]
        function_call: GeminiFunctionCall,
    },
    /// A file referenced by URI.
    #[serde(rename = "fileData")]
    FileData {
        #[serde(rename = "fileData")]
        file_data: GeminiFileData,
    },
    /// Code the model produced for execution.
    #[serde(rename = "executableCode")]
    ExecutableCode {
        #[serde(rename = "executableCode")]
        executable_code: GeminiExecutableCode,
    },
}

/// One part of a response candidate: flags shared by every part plus its payload.
#[derive(Debug, Deserialize)]
struct GeminiResponsePart {
    /// `true` when the part is model reasoning rather than answer content; absent means
    /// `false`.
    #[serde(default)]
    thought: bool,
    /// Thought signature attached to the part. Accepted as `thoughtSignature` or
    /// `thought_signature`.
    #[serde(default, rename = "thoughtSignature", alias = "thought_signature")]
    thought_signature: Option<String>,

    /// Remaining keys of the part, matched against [`GeminiResponsePartData`].
    #[serde(flatten)]
    data: GeminiResponsePartData,
}
203
/// A function call: emitted by the model in responses and replayed in requests.
#[derive(Debug, Serialize, Deserialize)]
struct GeminiFunctionCall {
    /// Name of the tool being called.
    name: String,
    /// Call arguments as an arbitrary JSON value.
    args: Value,
}

/// A `tools` entry: the function declarations offered to the model.
#[derive(Debug, Serialize, PartialEq)]
struct GeminiTool {
    #[serde(rename = "functionDeclarations")]
    function_declarations: Vec<GeminiFunctionDeclaration>,
}

/// Declaration of one callable tool.
#[derive(Debug, Serialize, PartialEq)]
struct GeminiFunctionDeclaration {
    name: String,
    description: String,
    parameters: GeminiParameterSchema,
}

/// Top-level parameter schema of a function declaration, already simplified for Gemini.
#[derive(Debug, Serialize, PartialEq)]
struct GeminiParameterSchema {
    /// Schema `type`; `convert_tools` falls back to `"object"` when the tool declares none.
    #[serde(rename = "type")]
    schema_type: String,
    /// Per-parameter schemas keyed by parameter name.
    properties: serde_json::Map<String, Value>,
    /// Names of required parameters.
    required: Vec<String>,
}
230
/// Deserialized Gemini response body. Every top-level field may be absent.
///
/// The `skip_serializing_if` attributes here are inert: the type only derives
/// `Deserialize`.
#[derive(Debug, Deserialize)]
struct GeminiResponse {
    /// Generated candidates; `convert_response` uses only the first one.
    #[serde(rename = "candidates")]
    #[serde(skip_serializing_if = "Option::is_none")]
    candidates: Option<Vec<GeminiCandidate>>,
    /// Present when the prompt itself was assessed (and possibly blocked).
    #[serde(rename = "promptFeedback")]
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt_feedback: Option<GeminiPromptFeedback>,
    /// Token counts, converted by `map_gemini_usage`.
    #[serde(rename = "usageMetadata")]
    #[serde(skip_serializing_if = "Option::is_none")]
    usage_metadata: Option<GeminiUsageMetadata>,
}

/// One generated candidate.
#[derive(Debug, Deserialize)]
struct GeminiCandidate {
    /// Generated content.
    /// NOTE(review): this field is required, so a candidate sent without `content`
    /// fails deserialization of the whole response — confirm the API always includes it.
    content: GeminiContentResponse,
    /// Why generation stopped; logged by `convert_response`.
    #[serde(rename = "finishReason")]
    #[serde(skip_serializing_if = "Option::is_none")]
    finish_reason: Option<GeminiFinishReason>,
    /// Safety ratings; only logged.
    #[serde(rename = "safetyRatings")]
    #[serde(skip_serializing_if = "Option::is_none")]
    safety_ratings: Option<Vec<GeminiSafetyRating>>,
    /// Citation information; only logged.
    #[serde(rename = "citationMetadata")]
    #[serde(skip_serializing_if = "Option::is_none")]
    citation_metadata: Option<GeminiCitationMetadata>,
}
257
258#[derive(Debug, Deserialize)]
259#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
260enum GeminiFinishReason {
261 FinishReasonUnspecified,
262 Stop,
263 MaxTokens,
264 Safety,
265 Recitation,
266 Other,
267 #[serde(rename = "TOOL_CODE_ERROR")]
268 ToolCodeError,
269 #[serde(rename = "TOOL_EXECUTION_HALT")]
270 ToolExecutionHalt,
271 MalformedFunctionCall,
272}
273
274#[derive(Debug, Deserialize)]
275struct GeminiPromptFeedback {
276 #[serde(rename = "blockReason")]
277 #[serde(skip_serializing_if = "Option::is_none")]
278 block_reason: Option<GeminiBlockReason>,
279 #[serde(rename = "safetyRatings")]
280 #[serde(skip_serializing_if = "Option::is_none")]
281 safety_ratings: Option<Vec<GeminiSafetyRating>>,
282}
283
284#[derive(Debug, Deserialize)]
285#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
286enum GeminiBlockReason {
287 BlockReasonUnspecified,
288 Safety,
289 Other,
290}
291
292#[derive(Debug, Deserialize)]
293#[expect(dead_code)]
294struct GeminiSafetyRating {
295 category: GeminiHarmCategory,
296 probability: GeminiHarmProbability,
297 #[serde(default)] blocked: bool,
299}
300
301#[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
303#[expect(clippy::enum_variant_names)]
304enum GeminiHarmCategory {
305 HarmCategoryUnspecified,
306 HarmCategoryDerogatory,
307 HarmCategoryToxicity,
308 HarmCategoryViolence,
309 HarmCategorySexual,
310 HarmCategoryMedical,
311 HarmCategoryDangerous,
312 HarmCategoryHarassment,
313 HarmCategoryHateSpeech,
314 HarmCategorySexuallyExplicit,
315 HarmCategoryDangerousContent,
316 HarmCategoryCivicIntegrity,
317}
318
319#[derive(Debug, Deserialize)]
320#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
321enum GeminiHarmProbability {
322 HarmProbabilityUnspecified,
323 Negligible,
324 Low,
325 Medium,
326 High,
327}
328
/// Citation information attached to a candidate. Only inspected through `Debug` in log
/// messages, hence the `dead_code` expectation.
#[expect(dead_code)]
#[derive(Debug, Deserialize)]
struct GeminiCitationMetadata {
    #[serde(rename = "citationSources")]
    #[serde(skip_serializing_if = "Option::is_none")]
    citation_sources: Option<Vec<GeminiCitationSource>>,
}

/// One cited source; every field may be absent.
#[expect(dead_code)]
#[derive(Debug, Deserialize)]
struct GeminiCitationSource {
    /// Start of the cited span, as reported by the API (units not specified here).
    #[serde(rename = "startIndex")]
    #[serde(skip_serializing_if = "Option::is_none")]
    start_index: Option<i32>,
    /// End of the cited span, as reported by the API.
    #[serde(rename = "endIndex")]
    #[serde(skip_serializing_if = "Option::is_none")]
    end_index: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    uri: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    license: Option<String>,
}
351
/// Token counts reported with a response; each may be absent. Converted to
/// `TokenUsage` by `map_gemini_usage`, which derives missing counts where it can.
/// NOTE(review): other counters the API may send (e.g. thinking tokens) are not captured.
#[derive(Debug, Deserialize)]
struct GeminiUsageMetadata {
    /// `promptTokenCount`.
    #[serde(rename = "promptTokenCount")]
    #[serde(skip_serializing_if = "Option::is_none")]
    prompt: Option<i32>,
    /// `candidatesTokenCount`.
    #[serde(rename = "candidatesTokenCount")]
    #[serde(skip_serializing_if = "Option::is_none")]
    candidates: Option<i32>,
    /// `totalTokenCount`.
    #[serde(rename = "totalTokenCount")]
    #[serde(skip_serializing_if = "Option::is_none")]
    total: Option<i32>,
}

/// `functionResponse` payload that reports a tool result back to the model.
#[derive(Debug, Serialize, Deserialize)]
struct GeminiFunctionResponse {
    /// Identifies which function call this response answers.
    name: String,
    response: GeminiResponseContent,
}

/// Object wrapping a tool result under a `content` key.
#[derive(Debug, Serialize, Deserialize)]
struct GeminiResponseContent {
    content: Value,
}

/// Code produced by the model for execution; `convert_response` renders it as a fenced
/// code block tagged with the lowercased `language`.
#[derive(Debug, Serialize, Deserialize)]
struct GeminiExecutableCode {
    language: String,
    code: String,
}
381
382#[derive(Debug, Deserialize)]
383#[expect(dead_code)]
384struct GeminiContentResponse {
385 role: String,
386 parts: Vec<GeminiResponsePart>,
387}
388
389fn map_gemini_usage(usage: &GeminiUsageMetadata) -> Option<TokenUsage> {
390 let prompt = usage.prompt.map(i64::from);
391 let candidates = usage.candidates.map(i64::from);
392 let total = usage.total.map(i64::from);
393
394 let input_tokens = prompt.or_else(|| match (total, candidates) {
395 (Some(total_tokens), Some(output_tokens)) if total_tokens >= output_tokens => {
396 Some(total_tokens - output_tokens)
397 }
398 _ => None,
399 })?;
400
401 let output_tokens = candidates.or_else(|| match (total, prompt) {
402 (Some(total_tokens), Some(input_tokens)) if total_tokens >= input_tokens => {
403 Some(total_tokens - input_tokens)
404 }
405 _ => None,
406 })?;
407
408 let total_tokens = total.or_else(|| input_tokens.checked_add(output_tokens))?;
409
410 Some(TokenUsage::new(
411 u32::try_from(input_tokens).ok()?,
412 u32::try_from(output_tokens).ok()?,
413 u32::try_from(total_tokens).ok()?,
414 ))
415}
416
417fn user_image_to_request_part(
418 image: &crate::app::conversation::ImageContent,
419) -> Result<GeminiRequestPart, ApiError> {
420 match &image.source {
421 ImageSource::DataUrl { data_url } => {
422 let Some((meta, payload)) = data_url.split_once(',') else {
423 return Err(ApiError::InvalidRequest {
424 provider: "google".to_string(),
425 details: "Image data URL is missing ',' separator".to_string(),
426 });
427 };
428
429 let Some(meta_body) = meta.strip_prefix("data:") else {
430 return Err(ApiError::InvalidRequest {
431 provider: "google".to_string(),
432 details: "Image data URL must start with 'data:'".to_string(),
433 });
434 };
435
436 if !meta_body
437 .split(';')
438 .any(|segment| segment.eq_ignore_ascii_case("base64"))
439 {
440 return Err(ApiError::InvalidRequest {
441 provider: "google".to_string(),
442 details: "Gemini image data URLs must be base64 encoded".to_string(),
443 });
444 }
445
446 let mime_type = meta_body
447 .split(';')
448 .next()
449 .filter(|value| !value.trim().is_empty())
450 .unwrap_or("application/octet-stream")
451 .to_string();
452
453 if !mime_type.starts_with("image/") {
454 return Err(ApiError::UnsupportedFeature {
455 provider: "google".to_string(),
456 feature: "image input mime type".to_string(),
457 details: format!("Unsupported image media type '{}'", mime_type),
458 });
459 }
460
461 Ok(GeminiRequestPart::InlineData {
462 inline_data: GeminiBlob {
463 mime_type,
464 data: payload.to_string(),
465 },
466 })
467 }
468 ImageSource::Url { url } => Ok(GeminiRequestPart::FileData {
469 file_data: GeminiFileData {
470 mime_type: image.mime_type.clone(),
471 file_uri: url.clone(),
472 },
473 }),
474 ImageSource::SessionFile { relative_path } => Err(ApiError::UnsupportedFeature {
475 provider: "google".to_string(),
476 feature: "image input source".to_string(),
477 details: format!(
478 "Gemini adapter cannot access session file '{}' directly; use data URLs or URL sources",
479 relative_path
480 ),
481 }),
482 }
483}
484
485fn convert_messages(messages: Vec<AppMessage>) -> Result<Vec<GeminiContent>, ApiError> {
486 let mut converted = Vec::new();
487
488 for msg in messages {
489 match &msg.data {
490 crate::app::conversation::MessageData::User { content, .. } => {
491 let mut parts = Vec::new();
492 for user_content in content {
493 match user_content {
494 UserContent::Text { text } => {
495 parts.push(GeminiRequestPart::Text { text: text.clone() });
496 }
497 UserContent::Image { image } => {
498 parts.push(user_image_to_request_part(image)?);
499 }
500 UserContent::CommandExecution {
501 command,
502 stdout,
503 stderr,
504 exit_code,
505 } => {
506 parts.push(GeminiRequestPart::Text {
507 text: UserContent::format_command_execution_as_xml(
508 command, stdout, stderr, *exit_code,
509 ),
510 });
511 }
512 }
513 }
514
515 if !parts.is_empty() {
516 converted.push(GeminiContent {
517 role: "user".to_string(),
518 parts,
519 });
520 }
521 }
522 crate::app::conversation::MessageData::Assistant { content, .. } => {
523 let mut parts = Vec::new();
524 for assistant_content in content {
525 match assistant_content {
526 AssistantContent::Text { text } => {
527 parts.push(GeminiRequestPart::Text { text: text.clone() });
528 }
529 AssistantContent::Image { image } => {
530 match user_image_to_request_part(image) {
531 Ok(part) => parts.push(part),
532 Err(err) => {
533 debug!(
534 target: "gemini::convert_messages",
535 "Skipping unsupported assistant image block: {}",
536 err
537 );
538 }
539 }
540 }
541 AssistantContent::ToolCall {
542 tool_call,
543 thought_signature,
544 } => parts.push(GeminiRequestPart::FunctionCall {
545 function_call: GeminiFunctionCall {
546 name: tool_call.name.clone(),
547 args: tool_call.parameters.clone(),
548 },
549 thought_signature: thought_signature
550 .as_ref()
551 .map(|signature| signature.as_str().to_string()),
552 }),
553 AssistantContent::Thought { .. } => {
554 }
556 }
557 }
558
559 converted.push(GeminiContent {
560 role: "model".to_string(),
561 parts,
562 });
563 }
564 crate::app::conversation::MessageData::Tool {
565 tool_use_id,
566 result,
567 ..
568 } => {
569 let result_value = match result {
570 ToolResult::Error(e) => Value::String(format!("Error: {e}")),
571 _ => serde_json::to_value(result)
572 .unwrap_or_else(|_| Value::String(result.llm_format())),
573 };
574
575 let parts = vec![GeminiRequestPart::FunctionResponse {
576 function_response: GeminiFunctionResponse {
577 name: tool_use_id.clone(),
578 response: GeminiResponseContent {
579 content: result_value,
580 },
581 },
582 }];
583
584 converted.push(GeminiContent {
585 role: "function".to_string(),
586 parts,
587 });
588 }
589 }
590 }
591
592 Ok(converted)
593}
594
595fn resolve_ref<'a>(root: &'a Value, schema: &'a Value) -> Option<&'a Value> {
596 let reference = schema.get("$ref").and_then(|v| v.as_str())?;
597 let path = reference.strip_prefix("#/")?;
598 let mut current = root;
599 for segment in path.split('/') {
600 current = current.get(segment)?;
601 }
602 Some(current)
603}
604
605fn infer_type_from_enum(values: &[Value]) -> Option<String> {
606 let mut has_string = false;
607 let mut has_number = false;
608 let mut has_bool = false;
609 let mut has_object = false;
610 let mut has_array = false;
611
612 for value in values {
613 match value {
614 Value::String(_) => has_string = true,
615 Value::Number(_) => has_number = true,
616 Value::Bool(_) => has_bool = true,
617 Value::Object(_) => has_object = true,
618 Value::Array(_) => has_array = true,
619 Value::Null => {}
620 }
621 }
622
623 let kind_count = u8::from(has_string)
624 + u8::from(has_number)
625 + u8::from(has_bool)
626 + u8::from(has_object)
627 + u8::from(has_array);
628
629 if kind_count != 1 {
630 return None;
631 }
632
633 if has_string {
634 Some("string".to_string())
635 } else if has_number {
636 Some("number".to_string())
637 } else if has_bool {
638 Some("boolean".to_string())
639 } else if has_object {
640 Some("object".to_string())
641 } else if has_array {
642 Some("array".to_string())
643 } else {
644 None
645 }
646}
647
648fn normalize_type(value: &Value) -> Value {
649 if let Some(type_str) = value.as_str() {
650 return Value::String(type_str.to_string());
651 }
652
653 if let Some(type_array) = value.as_array()
654 && let Some(primary_type) = type_array
655 .iter()
656 .find_map(|v| if v.is_null() { None } else { v.as_str() })
657 {
658 return Value::String(primary_type.to_string());
659 }
660
661 Value::String("string".to_string())
662}
663
664fn extract_enum_values(value: &Value) -> Vec<Value> {
665 let Some(obj) = value.as_object() else {
666 return Vec::new();
667 };
668
669 if let Some(enum_values) = obj.get("enum").and_then(|v| v.as_array()) {
670 return enum_values
671 .iter()
672 .filter(|v| !v.is_null())
673 .cloned()
674 .collect();
675 }
676
677 if let Some(const_value) = obj.get("const") {
678 if const_value.is_null() {
679 return Vec::new();
680 }
681 return vec![const_value.clone()];
682 }
683
684 Vec::new()
685}
686
687fn merge_property(properties: &mut serde_json::Map<String, Value>, key: &str, value: &Value) {
688 match properties.get_mut(key) {
689 None => {
690 properties.insert(key.to_string(), value.clone());
691 }
692 Some(existing) => {
693 if existing == value {
694 return;
695 }
696
697 let existing_values = extract_enum_values(existing);
698 let incoming_values = extract_enum_values(value);
699 if incoming_values.is_empty() && existing_values.is_empty() {
700 return;
701 }
702
703 let mut combined = existing_values;
704 for item in incoming_values {
705 if !combined.contains(&item) {
706 combined.push(item);
707 }
708 }
709
710 if combined.is_empty() {
711 return;
712 }
713
714 if let Some(obj) = existing.as_object_mut() {
715 obj.remove("const");
716 obj.insert("enum".to_string(), Value::Array(combined.clone()));
717 if !obj.contains_key("type")
718 && let Some(inferred) = infer_type_from_enum(&combined)
719 {
720 obj.insert("type".to_string(), Value::String(inferred));
721 }
722 }
723 }
724 }
725}
726
727fn merge_union_schemas(root: &Value, variants: &[Value]) -> Value {
728 let mut merged_props = serde_json::Map::new();
729 let mut required_intersection: Option<std::collections::BTreeSet<String>> = None;
730 let mut enum_values: Vec<Value> = Vec::new();
731 let mut type_candidates: Vec<String> = Vec::new();
732
733 for variant in variants {
734 let sanitized = sanitize_for_gemini(root, variant);
735
736 if let Some(schema_type) = sanitized.get("type").and_then(|v| v.as_str()) {
737 type_candidates.push(schema_type.to_string());
738 }
739
740 if let Some(props) = sanitized.get("properties").and_then(|v| v.as_object()) {
741 for (key, value) in props {
742 merge_property(&mut merged_props, key, value);
743 }
744 }
745
746 if let Some(req) = sanitized.get("required").and_then(|v| v.as_array()) {
747 let req_set: std::collections::BTreeSet<String> = req
748 .iter()
749 .filter_map(|item| item.as_str().map(|s| s.to_string()))
750 .collect();
751
752 required_intersection = match required_intersection.take() {
753 None => Some(req_set),
754 Some(existing) => Some(
755 existing
756 .intersection(&req_set)
757 .cloned()
758 .collect::<std::collections::BTreeSet<String>>(),
759 ),
760 };
761 }
762
763 if let Some(values) = sanitized.get("enum").and_then(|v| v.as_array()) {
764 for value in values {
765 if value.is_null() {
766 continue;
767 }
768 if !enum_values.contains(value) {
769 enum_values.push(value.clone());
770 }
771 }
772 }
773 }
774
775 let schema_type = if !merged_props.is_empty() {
776 "object".to_string()
777 } else if let Some(inferred) = infer_type_from_enum(&enum_values) {
778 inferred
779 } else if let Some(first) = type_candidates.first() {
780 first.clone()
781 } else {
782 "string".to_string()
783 };
784
785 let mut merged = serde_json::Map::new();
786 merged.insert("type".to_string(), Value::String(schema_type));
787
788 if !merged_props.is_empty() {
789 merged.insert("properties".to_string(), Value::Object(merged_props));
790 }
791
792 if let Some(required_set) = required_intersection
793 && !required_set.is_empty()
794 {
795 merged.insert(
796 "required".to_string(),
797 Value::Array(
798 required_set
799 .into_iter()
800 .map(Value::String)
801 .collect::<Vec<_>>(),
802 ),
803 );
804 }
805
806 if !enum_values.is_empty() {
807 merged.insert("enum".to_string(), Value::Array(enum_values));
808 }
809
810 Value::Object(merged)
811}
812
813fn sanitize_for_gemini(root: &Value, schema: &Value) -> Value {
814 if let Some(resolved) = resolve_ref(root, schema) {
815 return sanitize_for_gemini(root, resolved);
816 }
817
818 let Some(obj) = schema.as_object() else {
819 return schema.clone();
820 };
821
822 if let Some(union) = obj
823 .get("oneOf")
824 .or_else(|| obj.get("anyOf"))
825 .or_else(|| obj.get("allOf"))
826 .and_then(|v| v.as_array())
827 {
828 return merge_union_schemas(root, union);
829 }
830
831 let mut out = serde_json::Map::new();
832 for (key, value) in obj {
833 match key.as_str() {
834 "$ref"
835 | "$defs"
836 | "oneOf"
837 | "anyOf"
838 | "allOf"
839 | "const"
840 | "additionalProperties"
841 | "default"
842 | "examples"
843 | "title"
844 | "pattern"
845 | "minLength"
846 | "maxLength"
847 | "minimum"
848 | "maximum"
849 | "minItems"
850 | "maxItems"
851 | "uniqueItems"
852 | "deprecated" => {}
853 "type" => {
854 out.insert("type".to_string(), normalize_type(value));
855 }
856 "properties" => {
857 if let Some(props) = value.as_object() {
858 let mut sanitized_props = serde_json::Map::new();
859 for (prop_key, prop_value) in props {
860 sanitized_props
861 .insert(prop_key.clone(), sanitize_for_gemini(root, prop_value));
862 }
863 out.insert("properties".to_string(), Value::Object(sanitized_props));
864 }
865 }
866 "items" => {
867 out.insert("items".to_string(), sanitize_for_gemini(root, value));
868 }
869 "enum" => {
870 let values = value
871 .as_array()
872 .map(|items| {
873 items
874 .iter()
875 .filter(|v| !v.is_null())
876 .cloned()
877 .collect::<Vec<_>>()
878 })
879 .unwrap_or_default();
880 out.insert("enum".to_string(), Value::Array(values));
881 }
882 _ => {
883 out.insert(key.clone(), sanitize_for_gemini(root, value));
884 }
885 }
886 }
887
888 if let Some(const_value) = obj.get("const")
889 && !const_value.is_null()
890 {
891 out.insert("enum".to_string(), Value::Array(vec![const_value.clone()]));
892 if !out.contains_key("type")
893 && let Some(inferred) = infer_type_from_enum(std::slice::from_ref(const_value))
894 {
895 out.insert("type".to_string(), Value::String(inferred));
896 }
897 }
898
899 if out.get("type") == Some(&Value::String("object".to_string()))
900 && !out.contains_key("properties")
901 {
902 out.insert(
903 "properties".to_string(),
904 Value::Object(serde_json::Map::new()),
905 );
906 }
907
908 if !out.contains_key("type") {
909 if out.contains_key("properties") {
910 out.insert("type".to_string(), Value::String("object".to_string()));
911 } else if let Some(enum_values) = out.get("enum").and_then(|v| v.as_array())
912 && let Some(inferred) = infer_type_from_enum(enum_values)
913 {
914 out.insert("type".to_string(), Value::String(inferred));
915 }
916 }
917
918 Value::Object(out)
919}
920
921fn simplify_property_schema(key: &str, tool_name: &str, property_value: &Value) -> Value {
922 if let Some(prop_map_orig) = property_value.as_object() {
923 let mut simplified_prop = prop_map_orig.clone();
924
925 if simplified_prop.remove("additionalProperties").is_some() {
927 debug!(target: "gemini::simplify_property_schema", "Removed 'additionalProperties' from property '{}' in tool '{}'", key, tool_name);
928 }
929
930 if let Some(type_val) = simplified_prop.get_mut("type") {
932 if let Some(type_array) = type_val.as_array() {
933 if let Some(primary_type) = type_array
934 .iter()
935 .find_map(|v| if v.is_null() { None } else { v.as_str() })
936 {
937 *type_val = serde_json::Value::String(primary_type.to_string());
938 } else {
939 warn!(target: "gemini::simplify_property_schema", "Could not determine primary type for property '{}' in tool '{}', defaulting to string.", key, tool_name);
940 *type_val = serde_json::Value::String("string".to_string());
941 }
942 } else if !type_val.is_string() {
943 warn!(target: "gemini::simplify_property_schema", "Unexpected 'type' format for property '{}' in tool '{}': {:?}. Defaulting to string.", key, tool_name, type_val);
944 *type_val = serde_json::Value::String("string".to_string());
945 }
946 }
948
949 if simplified_prop.get("type") == Some(&serde_json::Value::String("integer".to_string()))
951 && let Some(format_val) = simplified_prop.get_mut("format")
952 && format_val.as_str() == Some("uint64")
953 {
954 *format_val = serde_json::Value::String("int64".to_string());
955 }
958
959 if simplified_prop.get("type") == Some(&serde_json::Value::String("string".to_string())) {
961 let should_remove_format = simplified_prop
962 .get("format")
963 .and_then(|f| f.as_str())
964 .is_some_and(|format_str| format_str != "enum" && format_str != "date-time");
965
966 if should_remove_format
967 && let Some(format_val) = simplified_prop.remove("format")
968 && let Some(format_str) = format_val.as_str()
969 {
970 debug!(target: "gemini::simplify_property_schema", "Removed unsupported format '{}' from string property '{}' in tool '{}'", format_str, key, tool_name);
971 }
972
973 if simplified_prop.remove("minLength").is_some() {
975 debug!(target: "gemini::simplify_property_schema", "Removed 'minLength' from string property '{}' in tool '{}'", key, tool_name);
976 }
977 if simplified_prop.remove("maxLength").is_some() {
978 debug!(target: "gemini::simplify_property_schema", "Removed 'maxLength' from string property '{}' in tool '{}'", key, tool_name);
979 }
980 if simplified_prop.remove("pattern").is_some() {
981 debug!(target: "gemini::simplify_property_schema", "Removed 'pattern' from string property '{}' in tool '{}'", key, tool_name);
982 }
983 }
984
985 if simplified_prop.get("type") == Some(&serde_json::Value::String("array".to_string()))
987 && let Some(items_val) = simplified_prop.get_mut("items")
988 {
989 *items_val = simplify_property_schema(&format!("{key}.items"), tool_name, items_val);
990 }
991
992 if simplified_prop.get("type") == Some(&serde_json::Value::String("object".to_string()))
994 && let Some(Value::Object(props)) = simplified_prop.get_mut("properties")
995 {
996 let simplified_nested_props: serde_json::Map<String, Value> = props
997 .iter()
998 .map(|(nested_key, nested_value)| {
999 (
1000 nested_key.clone(),
1001 simplify_property_schema(
1002 &format!("{key}.{nested_key}"),
1003 tool_name,
1004 nested_value,
1005 ),
1006 )
1007 })
1008 .collect();
1009 *props = simplified_nested_props;
1010 }
1011
1012 serde_json::Value::Object(simplified_prop)
1013 } else {
1014 warn!(target: "gemini::simplify_property_schema", "Property value for '{}' in tool '{}' is not an object: {:?}. Using original value.", key, tool_name, property_value);
1015 property_value.clone() }
1017}
1018
1019fn convert_tools(tools: Vec<ToolSchema>) -> Vec<GeminiTool> {
1020 let function_declarations = tools
1021 .into_iter()
1022 .map(|tool| {
1023 let root_schema = tool.input_schema.as_value();
1024 let summary = tool.input_schema.summary();
1025 let schema_type = if summary.schema_type.is_empty() {
1026 "object".to_string()
1027 } else {
1028 summary.schema_type
1029 };
1030
1031 let simplified_properties = summary
1033 .properties
1034 .iter()
1035 .map(|(key, value)| {
1036 let sanitized = sanitize_for_gemini(root_schema, value);
1037 (
1038 key.clone(),
1039 simplify_property_schema(key, &tool.name, &sanitized),
1040 )
1041 })
1042 .collect();
1043
1044 let parameters = GeminiParameterSchema {
1046 schema_type, properties: simplified_properties, required: summary.required, };
1050
1051 GeminiFunctionDeclaration {
1052 name: tool.name,
1053 description: tool.description,
1054 parameters,
1055 }
1056 })
1057 .collect();
1058
1059 vec![GeminiTool {
1060 function_declarations,
1061 }]
1062}
1063
1064fn convert_response(response: GeminiResponse) -> Result<CompletionResponse, ApiError> {
1065 if let Some(feedback) = &response.prompt_feedback
1067 && let Some(reason) = &feedback.block_reason
1068 {
1069 let details = format!(
1070 "Prompt blocked due to {:?}. Safety ratings: {:?}",
1071 reason, feedback.safety_ratings
1072 );
1073 warn!(target: "gemini::convert_response", "{}", details);
1074 return Err(ApiError::RequestBlocked {
1076 provider: "google".to_string(), details,
1078 });
1079 }
1080
1081 let candidates = if let Some(cands) = response.candidates {
1083 if cands.is_empty() {
1084 warn!(target: "gemini::convert_response", "No candidates received, and prompt was not blocked.");
1087 return Err(ApiError::NoChoices {
1089 provider: "google".to_string(),
1090 });
1091 }
1092 cands } else {
1094 warn!(target: "gemini::convert_response", "No candidates field in Gemini response.");
1095 return Err(ApiError::NoChoices {
1097 provider: "google".to_string(),
1098 });
1099 };
1100
1101 let candidate = &candidates[0];
1104
1105 if let Some(reason) = &candidate.finish_reason {
1107 match reason {
1108 GeminiFinishReason::Stop => { }
1109 GeminiFinishReason::MaxTokens => {
1110 warn!(target: "gemini::convert_response", "Response stopped due to MaxTokens limit.");
1111 }
1112 GeminiFinishReason::Safety => {
1113 warn!(target: "gemini::convert_response", "Response stopped due to safety settings. Ratings: {:?}", candidate.safety_ratings);
1114 }
1116 GeminiFinishReason::Recitation => {
1117 warn!(target: "gemini::convert_response", "Response stopped due to potential recitation. Citations: {:?}", candidate.citation_metadata);
1118 }
1119 GeminiFinishReason::MalformedFunctionCall => {
1120 warn!(target: "gemini::convert_response", "Response stopped due to malformed function call.");
1121 }
1122 _ => {
1123 info!(target: "gemini::convert_response", "Response finished with reason: {:?}", reason);
1124 }
1125 }
1126 }
1127
1128 if let Some(usage) = &response.usage_metadata {
1130 debug!(target: "gemini::convert_response", "Usage - Prompt Tokens: {:?}, Candidates Tokens: {:?}, Total Tokens: {:?}",
1131 usage.prompt, usage.candidates, usage.total);
1132 }
1133 let usage = response.usage_metadata.as_ref().and_then(map_gemini_usage);
1134
1135 let content: Vec<AssistantContent> = candidate
1136 .content .parts .iter()
1139 .filter_map(|part| { if part.thought {
1142 debug!(target: "gemini::convert_response", "Received thought part: {:?}", part);
1143 if let GeminiResponsePartData::Text { text } = &part.data {
1145 Some(AssistantContent::Thought {
1146 thought: ThoughtContent::Simple {
1147 text: text.clone(),
1148 },
1149 })
1150 } else {
1151 warn!(target: "gemini::convert_response", "Thought part contains non-text data: {:?}", part.data);
1152 None
1153 }
1154 } else {
1155 match &part.data {
1157 GeminiResponsePartData::Text { text } => Some(AssistantContent::Text {
1158 text: text.clone(),
1159 }),
1160 GeminiResponsePartData::InlineData { inline_data } => {
1161 warn!(target: "gemini::convert_response", "Received InlineData part (MIME type: {}). Converting to placeholder text.", inline_data.mime_type);
1162 Some(AssistantContent::Text { text: format!("[Inline Data: {}]", inline_data.mime_type) })
1163 }
1164 GeminiResponsePartData::FunctionCall { function_call } => {
1165 Some(AssistantContent::ToolCall {
1166 tool_call: steer_tools::ToolCall {
1167 id: uuid::Uuid::new_v4().to_string(), name: function_call.name.clone(),
1169 parameters: function_call.args.clone(),
1170 },
1171 thought_signature: part
1172 .thought_signature
1173 .clone()
1174 .map(ThoughtSignature::new),
1175 })
1176 }
1177 GeminiResponsePartData::FileData { file_data } => {
1178 warn!(target: "gemini::convert_response", "Received FileData part (URI: {}). Converting to placeholder text.", file_data.file_uri);
1179 Some(AssistantContent::Text { text: format!("[File Data: {}]", file_data.file_uri) })
1180 }
1181 GeminiResponsePartData::ExecutableCode { executable_code } => {
1182 info!(target: "gemini::convert_response", "Received ExecutableCode part ({}). Converting to text.",
1183 executable_code.language);
1184 Some(AssistantContent::Text {
1185 text: format!(
1186 "```{}
1187{}
1188```",
1189 executable_code.language.to_lowercase(),
1190 executable_code.code
1191 ),
1192 })
1193 }
1194 }
1195 }
1196 })
1197 .collect();
1198
1199 Ok(CompletionResponse { content, usage })
1200}
1201
1202#[async_trait]
1203impl Provider for GeminiClient {
1204 fn name(&self) -> &'static str {
1205 "google"
1206 }
1207
1208 async fn complete(
1209 &self,
1210 model_id: &ModelId,
1211 messages: Vec<AppMessage>,
1212 system: Option<SystemContext>,
1213 tools: Option<Vec<ToolSchema>>,
1214 _call_options: Option<ModelParameters>,
1215 token: CancellationToken,
1216 ) -> Result<CompletionResponse, ApiError> {
1217 let model_name = &model_id.id; let url = format!(
1219 "{}/models/{}:generateContent?key={}",
1220 GEMINI_API_BASE, model_name, self.api_key
1221 );
1222
1223 let gemini_contents = convert_messages(messages)?;
1224
1225 let system_instruction = system
1226 .and_then(|context| context.render())
1227 .map(|instructions| GeminiSystemInstruction {
1228 parts: vec![GeminiRequestPart::Text { text: instructions }],
1229 });
1230
1231 let gemini_tools = tools.map(convert_tools);
1232
1233 let (temperature, top_p, max_output_tokens) = {
1235 let opts = _call_options.as_ref();
1236 (
1237 opts.and_then(|o| o.temperature).or(Some(1.0)),
1238 opts.and_then(|o| o.top_p).or(Some(0.95)),
1239 opts.and_then(|o| o.max_output_tokens)
1240 .map(|v| v as i32)
1241 .or(Some(65536)),
1242 )
1243 };
1244 let thinking_config = _call_options
1245 .as_ref()
1246 .and_then(|o| o.thinking_config)
1247 .and_then(|tc| {
1248 if tc.enabled {
1249 Some(GeminiThinkingConfig {
1250 include_thoughts: tc.include_thoughts,
1251 thinking_budget: tc.budget_tokens.map(|v| v as i32),
1252 })
1253 } else {
1254 None
1255 }
1256 });
1257
1258 let request = GeminiRequest {
1259 contents: gemini_contents,
1260 system_instruction,
1261 tools: gemini_tools,
1262 generation_config: Some(GeminiGenerationConfig {
1263 max_output_tokens,
1264 temperature,
1265 top_p,
1266 thinking_config,
1267 ..Default::default()
1268 }),
1269 };
1270
1271 let response = tokio::select! {
1272 biased;
1273 () = token.cancelled() => {
1274 debug!(target: "gemini::complete", "Cancellation token triggered before sending request.");
1275 return Err(ApiError::Cancelled{ provider: self.name().to_string()});
1276 }
1277 res = self.client.post(&url).json(&request).send() => {
1278 res.map_err(ApiError::Network)?
1279 }
1280 };
1281 let status = response.status();
1282
1283 if status != StatusCode::OK {
1284 let error_text = response.text().await.map_err(ApiError::Network)?;
1285 error!(target: "Gemini API Error Response", "Status: {}, Body: {}", status, error_text);
1286 return Err(map_http_status_to_api_error(
1287 self.name(),
1288 status.as_u16(),
1289 error_text,
1290 ));
1291 }
1292
1293 let response_text = response.text().await.map_err(ApiError::Network)?;
1294
1295 match serde_json::from_str::<GeminiResponse>(&response_text) {
1296 Ok(gemini_response) => {
1297 convert_response(gemini_response).map_err(|e| ApiError::ResponseParsingError {
1298 provider: self.name().to_string(),
1299 details: e.to_string(),
1300 })
1301 }
1302 Err(e) => {
1303 error!(target: "Gemini API JSON Parsing Error", "Failed to parse JSON: {}. Response body:\n{}", e, response_text);
1304 Err(ApiError::ResponseParsingError {
1305 provider: self.name().to_string(),
1306 details: format!("Status: {status}, Error: {e}, Body: {response_text}"),
1307 })
1308 }
1309 }
1310 }
1311
1312 async fn stream_complete(
1313 &self,
1314 model_id: &ModelId,
1315 messages: Vec<AppMessage>,
1316 system: Option<SystemContext>,
1317 tools: Option<Vec<ToolSchema>>,
1318 _call_options: Option<ModelParameters>,
1319 token: CancellationToken,
1320 ) -> Result<CompletionStream, ApiError> {
1321 let model_name = &model_id.id;
1322 let url = format!(
1323 "{}/models/{}:streamGenerateContent?alt=sse&key={}",
1324 GEMINI_API_BASE, model_name, self.api_key
1325 );
1326
1327 let gemini_contents = convert_messages(messages)?;
1328
1329 let system_instruction = system
1330 .and_then(|context| context.render())
1331 .map(|instructions| GeminiSystemInstruction {
1332 parts: vec![GeminiRequestPart::Text { text: instructions }],
1333 });
1334
1335 let gemini_tools = tools.map(convert_tools);
1336
1337 let (temperature, top_p, max_output_tokens) = {
1338 let opts = _call_options.as_ref();
1339 (
1340 opts.and_then(|o| o.temperature).or(Some(1.0)),
1341 opts.and_then(|o| o.top_p).or(Some(0.95)),
1342 opts.and_then(|o| o.max_output_tokens)
1343 .map(|v| v as i32)
1344 .or(Some(65536)),
1345 )
1346 };
1347 let thinking_config = _call_options
1348 .as_ref()
1349 .and_then(|o| o.thinking_config)
1350 .and_then(|tc| {
1351 if tc.enabled {
1352 Some(GeminiThinkingConfig {
1353 include_thoughts: tc.include_thoughts,
1354 thinking_budget: tc.budget_tokens.map(|v| v as i32),
1355 })
1356 } else {
1357 None
1358 }
1359 });
1360
1361 let request = GeminiRequest {
1362 contents: gemini_contents,
1363 system_instruction,
1364 tools: gemini_tools,
1365 generation_config: Some(GeminiGenerationConfig {
1366 max_output_tokens,
1367 temperature,
1368 top_p,
1369 thinking_config,
1370 ..Default::default()
1371 }),
1372 };
1373
1374 let response = tokio::select! {
1375 biased;
1376 () = token.cancelled() => {
1377 return Err(ApiError::Cancelled{ provider: self.name().to_string()});
1378 }
1379 res = self.client.post(&url).json(&request).send() => {
1380 res.map_err(ApiError::Network)?
1381 }
1382 };
1383
1384 let status = response.status();
1385 if status != StatusCode::OK {
1386 let error_text = response.text().await.map_err(ApiError::Network)?;
1387 error!(target: "gemini::stream", "API error - Status: {}, Body: {}", status, error_text);
1388 return Err(map_http_status_to_api_error(
1389 self.name(),
1390 status.as_u16(),
1391 error_text,
1392 ));
1393 }
1394
1395 let byte_stream = response.bytes_stream();
1396 let sse_stream = parse_sse_stream(byte_stream);
1397
1398 Ok(Box::pin(Self::convert_gemini_stream(sse_stream, token)))
1399 }
1400}
1401
1402impl GeminiClient {
1403 fn convert_gemini_stream(
1404 mut sse_stream: impl futures::Stream<Item = Result<crate::api::sse::SseEvent, SseParseError>>
1405 + Unpin
1406 + Send
1407 + 'static,
1408 token: CancellationToken,
1409 ) -> impl futures::Stream<Item = StreamChunk> + Send + 'static {
1410 async_stream::stream! {
1411 let mut content: Vec<AssistantContent> = Vec::new();
1412 let mut latest_usage: Option<TokenUsage> = None;
1413 loop {
1414 if token.is_cancelled() {
1415 yield StreamChunk::Error(StreamError::Cancelled);
1416 break;
1417 }
1418
1419 let event_result = tokio::select! {
1420 biased;
1421 () = token.cancelled() => {
1422 yield StreamChunk::Error(StreamError::Cancelled);
1423 break;
1424 }
1425 event = sse_stream.next() => event
1426 };
1427
1428 let Some(event_result) = event_result else {
1429 let content = std::mem::take(&mut content);
1430 yield StreamChunk::MessageComplete(CompletionResponse {
1431 content,
1432 usage: latest_usage,
1433 });
1434 break;
1435 };
1436
1437 let event = match event_result {
1438 Ok(e) => e,
1439 Err(e) => {
1440 yield StreamChunk::Error(StreamError::SseParse(e));
1441 break;
1442 }
1443 };
1444
1445 let chunk: GeminiResponse = match serde_json::from_str(&event.data) {
1446 Ok(c) => c,
1447 Err(e) => {
1448 debug!(target: "gemini::stream", "Failed to parse chunk: {} data: {}", e, event.data);
1449 continue;
1450 }
1451 };
1452
1453 if let Some(usage) = chunk.usage_metadata.as_ref()
1454 && let Some(mapped) = map_gemini_usage(usage) {
1455 latest_usage = Some(mapped);
1456 }
1457
1458 if let Some(candidates) = chunk.candidates {
1459 for candidate in candidates {
1460 for part in candidate.content.parts {
1461 let GeminiResponsePart {
1462 thought,
1463 thought_signature,
1464 data,
1465 } = part;
1466 let thought_signature =
1467 thought_signature.map(ThoughtSignature::new);
1468
1469 if thought {
1470 if let GeminiResponsePartData::Text { text } = data {
1471 match content.last_mut() {
1472 Some(AssistantContent::Thought {
1473 thought: ThoughtContent::Simple { text: buf },
1474 }) => buf.push_str(&text),
1475 _ => {
1476 content.push(AssistantContent::Thought {
1477 thought: ThoughtContent::Simple { text: text.clone() },
1478 });
1479 }
1480 }
1481 yield StreamChunk::ThinkingDelta(text);
1482 }
1483 } else {
1484 match data {
1485 GeminiResponsePartData::Text { text } => {
1486 match content.last_mut() {
1487 Some(AssistantContent::Text { text: buf }) => buf.push_str(&text),
1488 _ => {
1489 content.push(AssistantContent::Text { text: text.clone() });
1490 }
1491 }
1492 yield StreamChunk::TextDelta(text);
1493 }
1494 GeminiResponsePartData::FunctionCall { function_call } => {
1495 let id = uuid::Uuid::new_v4().to_string();
1496 content.push(AssistantContent::ToolCall {
1497 tool_call: steer_tools::ToolCall {
1498 id: id.clone(),
1499 name: function_call.name.clone(),
1500 parameters: function_call.args.clone(),
1501 },
1502 thought_signature,
1503 });
1504 yield StreamChunk::ToolUseStart {
1505 id: id.clone(),
1506 name: function_call.name,
1507 };
1508 yield StreamChunk::ToolUseInputDelta {
1509 id,
1510 delta: function_call.args.to_string(),
1511 };
1512 }
1513 _ => {}
1514 }
1515 }
1516 }
1517 }
1518 }
1519 }
1520 }
1521 }
1522}
1523
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::sse::SseEvent;
    use crate::app::conversation::{ImageContent, ImageSource, Message, MessageData, UserContent};
    use futures::stream;
    use serde_json::json;

    /// Asserts that `simplify_property_schema` rewrites `input` into `expected`.
    fn assert_simplified(prop_name: &str, input: Value, expected: Value) {
        let result = simplify_property_schema(prop_name, "testTool", &input);
        assert_eq!(result, expected);
    }

    /// Wraps a JSON payload in an SSE event, as `parse_sse_stream` would yield it.
    fn sse_event(data: &str) -> Result<SseEvent, SseParseError> {
        Ok(SseEvent {
            event_type: None,
            data: data.to_string(),
            id: None,
        })
    }

    /// Runs `convert_gemini_stream` over `events` to completion and returns every
    /// chunk, which also checks that the stream terminates.
    async fn collect_chunks(
        events: Vec<Result<SseEvent, SseParseError>>,
        token: CancellationToken,
    ) -> Vec<StreamChunk> {
        GeminiClient::convert_gemini_stream(stream::iter(events), token)
            .collect()
            .await
    }

    /// A single user message holding one PNG image taken from `source`.
    fn user_image_message(source: ImageSource) -> Message {
        Message {
            data: MessageData::User {
                content: vec![UserContent::Image {
                    image: ImageContent {
                        mime_type: "image/png".to_string(),
                        source,
                        width: None,
                        height: None,
                        bytes: None,
                        sha256: None,
                    },
                }],
            },
            timestamp: 1,
            id: "msg-image".to_string(),
            parent_message_id: None,
        }
    }

    #[test]
    fn test_simplify_property_schema_removes_additional_properties() {
        assert_simplified(
            "testProp",
            json!({
                "type": "object",
                "properties": { "name": {"type": "string"} },
                "additionalProperties": false
            }),
            json!({
                "type": "object",
                "properties": { "name": {"type": "string"} }
            }),
        );
    }

    #[test]
    fn test_simplify_property_schema_removes_unsupported_string_formats() {
        assert_simplified(
            "urlProp",
            json!({
                "type": "string",
                "format": "uri",
                "minLength": 1,
                "maxLength": 100,
                "pattern": "^https://"
            }),
            json!({ "type": "string" }),
        );
    }

    #[test]
    fn test_simplify_property_schema_keeps_supported_string_formats() {
        assert_simplified(
            "dateProp",
            json!({ "type": "string", "format": "date-time" }),
            json!({ "type": "string", "format": "date-time" }),
        );
    }

    #[test]
    fn test_simplify_property_schema_handles_array_types() {
        assert_simplified(
            "emailProp",
            json!({ "type": ["string", "null"], "format": "email" }),
            json!({ "type": "string" }),
        );
    }

    #[test]
    fn test_simplify_property_schema_recursively_handles_array_items() {
        assert_simplified(
            "linksProp",
            json!({
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "url": { "type": "string", "format": "uri" }
                    },
                    "additionalProperties": false
                }
            }),
            json!({
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "url": { "type": "string" }
                    }
                }
            }),
        );
    }

    #[test]
    fn test_simplify_property_schema_recursively_handles_nested_objects() {
        assert_simplified(
            "complexProp",
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "type": "object",
                        "properties": {
                            "field": { "type": "string", "format": "hostname" }
                        },
                        "additionalProperties": true
                    }
                },
                "additionalProperties": false
            }),
            json!({
                "type": "object",
                "properties": {
                    "nested": {
                        "type": "object",
                        "properties": {
                            "field": { "type": "string" }
                        }
                    }
                }
            }),
        );
    }

    #[test]
    fn test_simplify_property_schema_fixes_uint64_format() {
        assert_simplified(
            "idProp",
            json!({ "type": "integer", "format": "uint64" }),
            json!({ "type": "integer", "format": "int64" }),
        );
    }

    #[test]
    fn test_convert_tools_integration() {
        use steer_tools::{InputSchema, ToolSchema};

        let mut input_props = serde_json::Map::new();
        input_props.insert(
            "title".to_string(),
            json!({ "type": "string", "minLength": 1 }),
        );
        input_props.insert(
            "links".to_string(),
            json!({
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "url": { "type": "string", "format": "uri" }
                    },
                    "additionalProperties": false
                }
            }),
        );

        let tool = ToolSchema {
            name: "create_issue".to_string(),
            display_name: "Create Issue".to_string(),
            description: "Create an issue".to_string(),
            input_schema: InputSchema::object(input_props, vec!["title".to_string()]),
        };

        let mut expected_props = serde_json::Map::new();
        expected_props.insert("title".to_string(), json!({ "type": "string" }));
        expected_props.insert(
            "links".to_string(),
            json!({
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "url": { "type": "string" }
                    }
                }
            }),
        );

        let expected_tools = vec![GeminiTool {
            function_declarations: vec![GeminiFunctionDeclaration {
                name: "create_issue".to_string(),
                description: "Create an issue".to_string(),
                parameters: GeminiParameterSchema {
                    schema_type: "object".to_string(),
                    properties: expected_props,
                    required: vec!["title".to_string()],
                },
            }],
        }];

        assert_eq!(convert_tools(vec![tool]), expected_tools);
    }

    #[test]
    fn test_convert_messages_includes_data_url_image_as_inline_data() {
        let messages = vec![user_image_message(ImageSource::DataUrl {
            data_url: "data:image/png;base64,Zm9v".to_string(),
        })];

        let converted = convert_messages(messages).expect("convert messages");
        assert_eq!(converted.len(), 1);
        let first = &converted[0];
        assert_eq!(first.role, "user");
        assert_eq!(first.parts.len(), 1);

        let json = serde_json::to_value(&first.parts[0]).expect("serialize part");
        assert_eq!(json["inlineData"]["mimeType"], "image/png");
        assert_eq!(json["inlineData"]["data"], "Zm9v");
    }

    #[test]
    fn test_convert_messages_rejects_session_file_image_source() {
        let messages = vec![user_image_message(ImageSource::SessionFile {
            relative_path: "session-1/image.png".to_string(),
        })];

        match convert_messages(messages).expect_err("expected unsupported feature") {
            ApiError::UnsupportedFeature {
                provider,
                feature,
                details,
            } => {
                assert_eq!(provider, "google");
                assert_eq!(feature, "image input source");
                assert!(details.contains("session file"));
            }
            other => panic!("Expected UnsupportedFeature, got {other:?}"),
        }
    }

    #[test]
    fn test_map_gemini_usage_derives_missing_counter() {
        let usage = GeminiUsageMetadata {
            prompt: None,
            candidates: Some(7),
            total: Some(19),
        };

        assert_eq!(map_gemini_usage(&usage), Some(TokenUsage::new(12, 7, 19)));
    }

    #[test]
    fn test_map_gemini_usage_returns_none_when_counters_not_derivable() {
        let usage = GeminiUsageMetadata {
            prompt: Some(9),
            candidates: None,
            total: None,
        };

        assert_eq!(map_gemini_usage(&usage), None);
    }

    #[test]
    fn test_convert_response_maps_usage_metadata() {
        let response = GeminiResponse {
            candidates: Some(vec![GeminiCandidate {
                content: GeminiContentResponse {
                    role: "model".to_string(),
                    parts: vec![GeminiResponsePart {
                        thought: false,
                        thought_signature: None,
                        data: GeminiResponsePartData::Text {
                            text: "Hello".to_string(),
                        },
                    }],
                },
                finish_reason: Some(GeminiFinishReason::Stop),
                safety_ratings: None,
                citation_metadata: None,
            }]),
            prompt_feedback: None,
            usage_metadata: Some(GeminiUsageMetadata {
                prompt: Some(11),
                candidates: Some(5),
                total: Some(16),
            }),
        };

        let converted = convert_response(response).expect("response should convert");

        assert_eq!(converted.usage, Some(TokenUsage::new(11, 5, 16)));
        assert!(matches!(
            converted.content.first(),
            Some(AssistantContent::Text { text }) if text == "Hello"
        ));
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_text_deltas() {
        let chunks = collect_chunks(
            vec![
                sse_event(r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#),
                sse_event(r#"{"candidates":[{"content":{"role":"model","parts":[{"text":" world"}]}}]}"#),
            ],
            CancellationToken::new(),
        )
        .await;

        assert_eq!(chunks.len(), 3, "two deltas followed by MessageComplete");
        assert!(matches!(&chunks[0], StreamChunk::TextDelta(t) if t == "Hello"));
        assert!(matches!(&chunks[1], StreamChunk::TextDelta(t) if t == " world"));
        assert!(matches!(&chunks[2], StreamChunk::MessageComplete(_)));
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_captures_final_usage() {
        let chunks = collect_chunks(
            vec![
                sse_event(r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#),
                sse_event(r#"{"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":4,"totalTokenCount":14}}"#),
            ],
            CancellationToken::new(),
        )
        .await;

        assert_eq!(chunks.len(), 2, "one delta followed by MessageComplete");
        assert!(matches!(&chunks[0], StreamChunk::TextDelta(t) if t == "Hello"));

        let StreamChunk::MessageComplete(response) = &chunks[1] else {
            panic!("Expected MessageComplete");
        };
        assert_eq!(response.usage, Some(TokenUsage::new(10, 4, 14)));
        assert!(matches!(
            response.content.first(),
            Some(AssistantContent::Text { text }) if text == "Hello"
        ));
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_with_thinking() {
        let chunks = collect_chunks(
            vec![
                sse_event(r#"{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"Let me think..."}]}}]}"#),
                sse_event(r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"The answer"}]}}]}"#),
            ],
            CancellationToken::new(),
        )
        .await;

        assert_eq!(chunks.len(), 3, "thinking delta, text delta, MessageComplete");
        assert!(matches!(&chunks[0], StreamChunk::ThinkingDelta(t) if t == "Let me think..."));
        assert!(matches!(&chunks[1], StreamChunk::TextDelta(t) if t == "The answer"));

        let StreamChunk::MessageComplete(response) = &chunks[2] else {
            panic!("Expected MessageComplete");
        };
        assert_eq!(response.content.len(), 2);
        assert!(matches!(&response.content[0], AssistantContent::Thought { .. }));
        assert!(matches!(&response.content[1], AssistantContent::Text { .. }));
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_with_function_call() {
        let chunks = collect_chunks(
            vec![sse_event(
                r#"{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"get_weather","args":{"city":"NYC"}},"thoughtSignature":"sig_123"}]}}]}"#,
            )],
            CancellationToken::new(),
        )
        .await;

        assert_eq!(chunks.len(), 3, "tool start, tool input, MessageComplete");
        assert!(
            matches!(&chunks[0], StreamChunk::ToolUseStart { name, .. } if name == "get_weather")
        );
        let StreamChunk::ToolUseInputDelta { delta, .. } = &chunks[1] else {
            panic!("Expected ToolUseInputDelta");
        };
        let parsed: Value = serde_json::from_str(delta).expect("tool input should be JSON");
        assert_eq!(parsed, json!({"city": "NYC"}));

        let StreamChunk::MessageComplete(response) = &chunks[2] else {
            panic!("Expected MessageComplete");
        };
        assert_eq!(response.content.len(), 1);
        let AssistantContent::ToolCall {
            tool_call,
            thought_signature,
        } = &response.content[0]
        else {
            panic!("Expected ToolCall");
        };
        assert_eq!(tool_call.name, "get_weather");
        assert_eq!(
            thought_signature.as_ref().map(|sig| sig.as_str()),
            Some("sig_123")
        );
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_cancellation() {
        let token = CancellationToken::new();
        token.cancel();

        let chunks = collect_chunks(
            vec![sse_event(
                r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#,
            )],
            token,
        )
        .await;

        assert_eq!(chunks.len(), 1, "cancellation must end the stream immediately");
        assert!(matches!(
            &chunks[0],
            StreamChunk::Error(StreamError::Cancelled)
        ));
    }

    #[tokio::test]
    #[ignore = "Requires GOOGLE_API_KEY environment variable"]
    async fn test_stream_complete_real_api() {
        dotenvy::dotenv().ok();
        let api_key = std::env::var("GOOGLE_API_KEY").expect("GOOGLE_API_KEY must be set");
        let client = GeminiClient::new(api_key);

        let message = Message {
            data: MessageData::User {
                content: vec![UserContent::Text {
                    text: "Say exactly: Hello".to_string(),
                }],
            },
            timestamp: chrono::Utc::now().timestamp_millis() as u64,
            id: "test-msg".to_string(),
            parent_message_id: None,
        };

        // The dated 2.5 Flash preview snapshots have been retired; use the stable model.
        let model_id = ModelId::new(crate::config::provider::google(), "gemini-2.5-flash");
        let token = CancellationToken::new();

        let mut stream = client
            .stream_complete(&model_id, vec![message], None, None, None, token)
            .await
            .expect("stream_complete should succeed");

        let mut got_text_delta = false;
        let mut got_message_complete = false;
        let mut accumulated_text = String::new();

        while let Some(chunk) = stream.next().await {
            match chunk {
                StreamChunk::TextDelta(text) => {
                    got_text_delta = true;
                    accumulated_text.push_str(&text);
                }
                StreamChunk::MessageComplete(response) => {
                    got_message_complete = true;
                    assert!(!response.content.is_empty());
                }
                StreamChunk::Error(e) => panic!("Unexpected error: {e:?}"),
                _ => {}
            }
        }

        assert!(got_text_delta, "Should receive at least one TextDelta");
        assert!(
            got_message_complete,
            "Should receive MessageComplete at the end"
        );
        assert!(
            accumulated_text.to_lowercase().contains("hello"),
            "Response should contain 'hello', got: {accumulated_text}"
        );
    }
}