1use async_trait::async_trait;
2use futures::StreamExt;
3use reqwest::{Client as HttpClient, StatusCode};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use tokio_util::sync::CancellationToken;
7use tracing::{debug, error, info, warn};
8
9use crate::api::error::{ApiError, SseParseError, StreamError};
10use crate::api::provider::{
11 CompletionResponse, CompletionStream, Provider, StreamChunk, TokenUsage,
12};
13use crate::api::sse::parse_sse_stream;
14use crate::app::SystemContext;
15use crate::app::conversation::{
16 AssistantContent, ImageSource, Message as AppMessage, ThoughtContent, ThoughtSignature,
17 ToolResult, UserContent,
18};
19use crate::config::model::{ModelId, ModelParameters};
20use steer_tools::ToolSchema;
21
22const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
23
24#[derive(Debug, Deserialize, Serialize, Clone)] struct GeminiBlob {
26 #[serde(rename = "mimeType")]
27 mime_type: String,
28 data: String, }
30
31#[derive(Debug, Deserialize, Serialize, Clone)] struct GeminiFileData {
33 #[serde(rename = "mimeType")]
34 mime_type: String,
35 #[serde(rename = "fileUri")]
36 file_uri: String,
37}
38
39#[expect(dead_code)]
40#[derive(Debug, Deserialize, Serialize, Clone)] struct GeminiCodeExecutionResult {
42 outcome: String, }
45
46pub struct GeminiClient {
47 api_key: String,
48 client: HttpClient,
49}
50
51impl GeminiClient {
52 pub fn new(api_key: impl Into<String>) -> Self {
53 Self {
54 api_key: api_key.into(),
55 client: HttpClient::new(),
56 }
57 }
58}
59
60#[derive(Debug, Serialize)]
61struct GeminiRequest {
62 contents: Vec<GeminiContent>,
63 #[serde(skip_serializing_if = "Option::is_none")]
64 #[serde(rename = "systemInstruction")]
65 system_instruction: Option<GeminiSystemInstruction>,
66 #[serde(skip_serializing_if = "Option::is_none")]
67 tools: Option<Vec<GeminiTool>>,
68 #[serde(skip_serializing_if = "Option::is_none")]
69 #[serde(rename = "generationConfig")]
70 generation_config: Option<GeminiGenerationConfig>,
71}
72
73#[derive(Debug, Serialize, Default, Clone)]
74struct GeminiGenerationConfig {
75 #[serde(skip_serializing_if = "Option::is_none")]
76 #[serde(rename = "stopSequences")]
77 stop_sequences: Option<Vec<String>>,
78 #[serde(skip_serializing_if = "Option::is_none")]
79 #[serde(rename = "responseMimeType")]
80 response_mime_type: Option<GeminiMimeType>,
81 #[serde(skip_serializing_if = "Option::is_none")]
82 #[serde(rename = "candidateCount")]
83 candidate_count: Option<i32>,
84 #[serde(skip_serializing_if = "Option::is_none")]
85 #[serde(rename = "maxOutputTokens")]
86 max_output_tokens: Option<i32>,
87 #[serde(skip_serializing_if = "Option::is_none")]
88 temperature: Option<f32>,
89 #[serde(skip_serializing_if = "Option::is_none")]
90 #[serde(rename = "topP")]
91 top_p: Option<f32>,
92 #[serde(skip_serializing_if = "Option::is_none")]
93 #[serde(rename = "topK")]
94 top_k: Option<i32>,
95 #[serde(skip_serializing_if = "Option::is_none")]
96 #[serde(rename = "thinkingConfig")]
97 thinking_config: Option<GeminiThinkingConfig>,
98}
99
100#[derive(Debug, Serialize, Default, Clone)]
101struct GeminiThinkingConfig {
102 #[serde(skip_serializing_if = "Option::is_none")]
103 #[serde(rename = "includeThoughts")]
104 include_thoughts: Option<bool>,
105 #[serde(skip_serializing_if = "Option::is_none")]
106 #[serde(rename = "thinkingBudget")]
107 thinking_budget: Option<i32>,
108}
109
110#[expect(dead_code)]
111#[derive(Debug, Serialize, Clone)]
112#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
113enum GeminiMimeType {
114 MimeTypeUnspecified,
115 TextPlain,
116 ApplicationJson,
117}
118
119#[derive(Debug, Serialize)]
120struct GeminiSystemInstruction {
121 parts: Vec<GeminiRequestPart>,
122}
123
124#[derive(Debug, Serialize)]
125struct GeminiContent {
126 role: String,
127 parts: Vec<GeminiRequestPart>,
128}
129
130#[derive(Debug, Serialize)]
132#[serde(untagged)]
133enum GeminiRequestPart {
134 Text {
135 text: String,
136 },
137 #[serde(rename = "inlineData")]
138 InlineData {
139 #[serde(rename = "inlineData")]
140 inline_data: GeminiBlob,
141 },
142 #[serde(rename = "fileData")]
143 FileData {
144 #[serde(rename = "fileData")]
145 file_data: GeminiFileData,
146 },
147 #[serde(rename = "functionCall")]
148 FunctionCall {
149 #[serde(rename = "functionCall")]
150 function_call: GeminiFunctionCall, #[serde(rename = "thoughtSignature", skip_serializing_if = "Option::is_none")]
152 thought_signature: Option<String>,
153 },
154 #[serde(rename = "functionResponse")]
155 FunctionResponse {
156 #[serde(rename = "functionResponse")]
157 function_response: GeminiFunctionResponse, },
159}
160
161#[derive(Debug, Deserialize)]
163#[serde(untagged)]
164enum GeminiResponsePartData {
165 Text {
166 text: String,
167 },
168 #[serde(rename = "inlineData")]
169 InlineData {
170 #[serde(rename = "inlineData")]
171 inline_data: GeminiBlob,
172 },
173 #[serde(rename = "functionCall")]
174 FunctionCall {
175 #[serde(rename = "functionCall")]
176 function_call: GeminiFunctionCall,
177 },
178 #[serde(rename = "fileData")]
179 FileData {
180 #[serde(rename = "fileData")]
181 file_data: GeminiFileData,
182 },
183 #[serde(rename = "executableCode")]
184 ExecutableCode {
185 #[serde(rename = "executableCode")]
186 executable_code: GeminiExecutableCode,
187 },
188 }
190
191#[derive(Debug, Deserialize)]
193struct GeminiResponsePart {
194 #[serde(default)] thought: bool,
196 #[serde(default, rename = "thoughtSignature", alias = "thought_signature")]
197 thought_signature: Option<String>,
198
199 #[serde(flatten)] data: GeminiResponsePartData,
201}
202
203#[derive(Debug, Serialize, Deserialize)]
204struct GeminiFunctionCall {
205 name: String,
206 args: Value,
207}
208
209#[derive(Debug, Serialize, PartialEq)]
210struct GeminiTool {
211 #[serde(rename = "functionDeclarations")]
212 function_declarations: Vec<GeminiFunctionDeclaration>,
213}
214
215#[derive(Debug, Serialize, PartialEq)]
216struct GeminiFunctionDeclaration {
217 name: String,
218 description: String,
219 parameters: GeminiParameterSchema,
220}
221
222#[derive(Debug, Serialize, PartialEq)]
223struct GeminiParameterSchema {
224 #[serde(rename = "type")]
225 schema_type: String, properties: serde_json::Map<String, Value>,
227 required: Vec<String>,
228}
229
230#[derive(Debug, Deserialize)]
231struct GeminiResponse {
232 #[serde(rename = "candidates")]
233 #[serde(skip_serializing_if = "Option::is_none")]
234 candidates: Option<Vec<GeminiCandidate>>,
235 #[serde(rename = "promptFeedback")]
236 #[serde(skip_serializing_if = "Option::is_none")]
237 prompt_feedback: Option<GeminiPromptFeedback>,
238 #[serde(rename = "usageMetadata")]
239 #[serde(skip_serializing_if = "Option::is_none")]
240 usage_metadata: Option<GeminiUsageMetadata>,
241}
242
243#[derive(Debug, Deserialize)]
244struct GeminiCandidate {
245 content: GeminiContentResponse,
246 #[serde(rename = "finishReason")]
247 #[serde(skip_serializing_if = "Option::is_none")]
248 finish_reason: Option<GeminiFinishReason>,
249 #[serde(rename = "safetyRatings")]
250 #[serde(skip_serializing_if = "Option::is_none")]
251 safety_ratings: Option<Vec<GeminiSafetyRating>>,
252 #[serde(rename = "citationMetadata")]
253 #[serde(skip_serializing_if = "Option::is_none")]
254 citation_metadata: Option<GeminiCitationMetadata>,
255}
256
257#[derive(Debug, Deserialize)]
258#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
259enum GeminiFinishReason {
260 FinishReasonUnspecified,
261 Stop,
262 MaxTokens,
263 Safety,
264 Recitation,
265 Other,
266 #[serde(rename = "TOOL_CODE_ERROR")]
267 ToolCodeError,
268 #[serde(rename = "TOOL_EXECUTION_HALT")]
269 ToolExecutionHalt,
270 MalformedFunctionCall,
271}
272
273#[derive(Debug, Deserialize)]
274struct GeminiPromptFeedback {
275 #[serde(rename = "blockReason")]
276 #[serde(skip_serializing_if = "Option::is_none")]
277 block_reason: Option<GeminiBlockReason>,
278 #[serde(rename = "safetyRatings")]
279 #[serde(skip_serializing_if = "Option::is_none")]
280 safety_ratings: Option<Vec<GeminiSafetyRating>>,
281}
282
283#[derive(Debug, Deserialize)]
284#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
285enum GeminiBlockReason {
286 BlockReasonUnspecified,
287 Safety,
288 Other,
289}
290
291#[derive(Debug, Deserialize)]
292#[expect(dead_code)]
293struct GeminiSafetyRating {
294 category: GeminiHarmCategory,
295 probability: GeminiHarmProbability,
296 #[serde(default)] blocked: bool,
298}
299
300#[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
302#[expect(clippy::enum_variant_names)]
303enum GeminiHarmCategory {
304 HarmCategoryUnspecified,
305 HarmCategoryDerogatory,
306 HarmCategoryToxicity,
307 HarmCategoryViolence,
308 HarmCategorySexual,
309 HarmCategoryMedical,
310 HarmCategoryDangerous,
311 HarmCategoryHarassment,
312 HarmCategoryHateSpeech,
313 HarmCategorySexuallyExplicit,
314 HarmCategoryDangerousContent,
315 HarmCategoryCivicIntegrity,
316}
317
318#[derive(Debug, Deserialize)]
319#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
320enum GeminiHarmProbability {
321 HarmProbabilityUnspecified,
322 Negligible,
323 Low,
324 Medium,
325 High,
326}
327
328#[expect(dead_code)]
329#[derive(Debug, Deserialize)]
330struct GeminiCitationMetadata {
331 #[serde(rename = "citationSources")]
332 #[serde(skip_serializing_if = "Option::is_none")]
333 citation_sources: Option<Vec<GeminiCitationSource>>,
334}
335
336#[expect(dead_code)]
337#[derive(Debug, Deserialize)]
338struct GeminiCitationSource {
339 #[serde(rename = "startIndex")]
340 #[serde(skip_serializing_if = "Option::is_none")]
341 start_index: Option<i32>,
342 #[serde(rename = "endIndex")]
343 #[serde(skip_serializing_if = "Option::is_none")]
344 end_index: Option<i32>,
345 #[serde(skip_serializing_if = "Option::is_none")]
346 uri: Option<String>,
347 #[serde(skip_serializing_if = "Option::is_none")]
348 license: Option<String>,
349}
350
351#[derive(Debug, Deserialize)]
352struct GeminiUsageMetadata {
353 #[serde(rename = "promptTokenCount")]
354 #[serde(skip_serializing_if = "Option::is_none")]
355 prompt: Option<i32>,
356 #[serde(rename = "candidatesTokenCount")]
357 #[serde(skip_serializing_if = "Option::is_none")]
358 candidates: Option<i32>,
359 #[serde(rename = "totalTokenCount")]
360 #[serde(skip_serializing_if = "Option::is_none")]
361 total: Option<i32>,
362}
363
364#[derive(Debug, Serialize, Deserialize)]
365struct GeminiFunctionResponse {
366 name: String,
367 response: GeminiResponseContent,
368}
369
370#[derive(Debug, Serialize, Deserialize)]
371struct GeminiResponseContent {
372 content: Value,
373}
374
375#[derive(Debug, Serialize, Deserialize)]
376struct GeminiExecutableCode {
377 language: String, code: String,
379}
380
381#[derive(Debug, Deserialize)]
382#[expect(dead_code)]
383struct GeminiContentResponse {
384 role: String,
385 parts: Vec<GeminiResponsePart>,
386}
387
388fn map_gemini_usage(usage: &GeminiUsageMetadata) -> Option<TokenUsage> {
389 let prompt = usage.prompt.map(i64::from);
390 let candidates = usage.candidates.map(i64::from);
391 let total = usage.total.map(i64::from);
392
393 let input_tokens = prompt.or_else(|| match (total, candidates) {
394 (Some(total_tokens), Some(output_tokens)) if total_tokens >= output_tokens => {
395 Some(total_tokens - output_tokens)
396 }
397 _ => None,
398 })?;
399
400 let output_tokens = candidates.or_else(|| match (total, prompt) {
401 (Some(total_tokens), Some(input_tokens)) if total_tokens >= input_tokens => {
402 Some(total_tokens - input_tokens)
403 }
404 _ => None,
405 })?;
406
407 let total_tokens = total.or_else(|| input_tokens.checked_add(output_tokens))?;
408
409 Some(TokenUsage::new(
410 u32::try_from(input_tokens).ok()?,
411 u32::try_from(output_tokens).ok()?,
412 u32::try_from(total_tokens).ok()?,
413 ))
414}
415
416fn user_image_to_request_part(
417 image: &crate::app::conversation::ImageContent,
418) -> Result<GeminiRequestPart, ApiError> {
419 match &image.source {
420 ImageSource::DataUrl { data_url } => {
421 let Some((meta, payload)) = data_url.split_once(',') else {
422 return Err(ApiError::InvalidRequest {
423 provider: "google".to_string(),
424 details: "Image data URL is missing ',' separator".to_string(),
425 });
426 };
427
428 let Some(meta_body) = meta.strip_prefix("data:") else {
429 return Err(ApiError::InvalidRequest {
430 provider: "google".to_string(),
431 details: "Image data URL must start with 'data:'".to_string(),
432 });
433 };
434
435 if !meta_body
436 .split(';')
437 .any(|segment| segment.eq_ignore_ascii_case("base64"))
438 {
439 return Err(ApiError::InvalidRequest {
440 provider: "google".to_string(),
441 details: "Gemini image data URLs must be base64 encoded".to_string(),
442 });
443 }
444
445 let mime_type = meta_body
446 .split(';')
447 .next()
448 .filter(|value| !value.trim().is_empty())
449 .unwrap_or("application/octet-stream")
450 .to_string();
451
452 if !mime_type.starts_with("image/") {
453 return Err(ApiError::UnsupportedFeature {
454 provider: "google".to_string(),
455 feature: "image input mime type".to_string(),
456 details: format!("Unsupported image media type '{}'", mime_type),
457 });
458 }
459
460 Ok(GeminiRequestPart::InlineData {
461 inline_data: GeminiBlob {
462 mime_type,
463 data: payload.to_string(),
464 },
465 })
466 }
467 ImageSource::Url { url } => Ok(GeminiRequestPart::FileData {
468 file_data: GeminiFileData {
469 mime_type: image.mime_type.clone(),
470 file_uri: url.clone(),
471 },
472 }),
473 ImageSource::SessionFile { relative_path } => Err(ApiError::UnsupportedFeature {
474 provider: "google".to_string(),
475 feature: "image input source".to_string(),
476 details: format!(
477 "Gemini adapter cannot access session file '{}' directly; use data URLs or URL sources",
478 relative_path
479 ),
480 }),
481 }
482}
483
484fn convert_messages(messages: Vec<AppMessage>) -> Result<Vec<GeminiContent>, ApiError> {
485 let mut converted = Vec::new();
486
487 for msg in messages {
488 match &msg.data {
489 crate::app::conversation::MessageData::User { content, .. } => {
490 let mut parts = Vec::new();
491 for user_content in content {
492 match user_content {
493 UserContent::Text { text } => {
494 parts.push(GeminiRequestPart::Text { text: text.clone() });
495 }
496 UserContent::Image { image } => {
497 parts.push(user_image_to_request_part(image)?);
498 }
499 UserContent::CommandExecution {
500 command,
501 stdout,
502 stderr,
503 exit_code,
504 } => {
505 parts.push(GeminiRequestPart::Text {
506 text: UserContent::format_command_execution_as_xml(
507 command, stdout, stderr, *exit_code,
508 ),
509 });
510 }
511 }
512 }
513
514 if !parts.is_empty() {
515 converted.push(GeminiContent {
516 role: "user".to_string(),
517 parts,
518 });
519 }
520 }
521 crate::app::conversation::MessageData::Assistant { content, .. } => {
522 let mut parts = Vec::new();
523 for assistant_content in content {
524 match assistant_content {
525 AssistantContent::Text { text } => {
526 parts.push(GeminiRequestPart::Text { text: text.clone() });
527 }
528 AssistantContent::Image { image } => {
529 match user_image_to_request_part(image) {
530 Ok(part) => parts.push(part),
531 Err(err) => {
532 debug!(
533 target: "gemini::convert_messages",
534 "Skipping unsupported assistant image block: {}",
535 err
536 );
537 }
538 }
539 }
540 AssistantContent::ToolCall {
541 tool_call,
542 thought_signature,
543 } => parts.push(GeminiRequestPart::FunctionCall {
544 function_call: GeminiFunctionCall {
545 name: tool_call.name.clone(),
546 args: tool_call.parameters.clone(),
547 },
548 thought_signature: thought_signature
549 .as_ref()
550 .map(|signature| signature.as_str().to_string()),
551 }),
552 AssistantContent::Thought { .. } => {
553 }
555 }
556 }
557
558 converted.push(GeminiContent {
559 role: "model".to_string(),
560 parts,
561 });
562 }
563 crate::app::conversation::MessageData::Tool {
564 tool_use_id,
565 result,
566 ..
567 } => {
568 let result_value = match result {
569 ToolResult::Error(e) => Value::String(format!("Error: {e}")),
570 _ => serde_json::to_value(result)
571 .unwrap_or_else(|_| Value::String(result.llm_format())),
572 };
573
574 let parts = vec![GeminiRequestPart::FunctionResponse {
575 function_response: GeminiFunctionResponse {
576 name: tool_use_id.clone(),
577 response: GeminiResponseContent {
578 content: result_value,
579 },
580 },
581 }];
582
583 converted.push(GeminiContent {
584 role: "function".to_string(),
585 parts,
586 });
587 }
588 }
589 }
590
591 Ok(converted)
592}
593
594fn resolve_ref<'a>(root: &'a Value, schema: &'a Value) -> Option<&'a Value> {
595 let reference = schema.get("$ref").and_then(|v| v.as_str())?;
596 let path = reference.strip_prefix("#/")?;
597 let mut current = root;
598 for segment in path.split('/') {
599 current = current.get(segment)?;
600 }
601 Some(current)
602}
603
604fn infer_type_from_enum(values: &[Value]) -> Option<String> {
605 let mut has_string = false;
606 let mut has_number = false;
607 let mut has_bool = false;
608 let mut has_object = false;
609 let mut has_array = false;
610
611 for value in values {
612 match value {
613 Value::String(_) => has_string = true,
614 Value::Number(_) => has_number = true,
615 Value::Bool(_) => has_bool = true,
616 Value::Object(_) => has_object = true,
617 Value::Array(_) => has_array = true,
618 Value::Null => {}
619 }
620 }
621
622 let kind_count = u8::from(has_string)
623 + u8::from(has_number)
624 + u8::from(has_bool)
625 + u8::from(has_object)
626 + u8::from(has_array);
627
628 if kind_count != 1 {
629 return None;
630 }
631
632 if has_string {
633 Some("string".to_string())
634 } else if has_number {
635 Some("number".to_string())
636 } else if has_bool {
637 Some("boolean".to_string())
638 } else if has_object {
639 Some("object".to_string())
640 } else if has_array {
641 Some("array".to_string())
642 } else {
643 None
644 }
645}
646
647fn normalize_type(value: &Value) -> Value {
648 if let Some(type_str) = value.as_str() {
649 return Value::String(type_str.to_string());
650 }
651
652 if let Some(type_array) = value.as_array()
653 && let Some(primary_type) = type_array
654 .iter()
655 .find_map(|v| if v.is_null() { None } else { v.as_str() })
656 {
657 return Value::String(primary_type.to_string());
658 }
659
660 Value::String("string".to_string())
661}
662
663fn extract_enum_values(value: &Value) -> Vec<Value> {
664 let Some(obj) = value.as_object() else {
665 return Vec::new();
666 };
667
668 if let Some(enum_values) = obj.get("enum").and_then(|v| v.as_array()) {
669 return enum_values
670 .iter()
671 .filter(|v| !v.is_null())
672 .cloned()
673 .collect();
674 }
675
676 if let Some(const_value) = obj.get("const") {
677 if const_value.is_null() {
678 return Vec::new();
679 }
680 return vec![const_value.clone()];
681 }
682
683 Vec::new()
684}
685
686fn merge_property(properties: &mut serde_json::Map<String, Value>, key: &str, value: &Value) {
687 match properties.get_mut(key) {
688 None => {
689 properties.insert(key.to_string(), value.clone());
690 }
691 Some(existing) => {
692 if existing == value {
693 return;
694 }
695
696 let existing_values = extract_enum_values(existing);
697 let incoming_values = extract_enum_values(value);
698 if incoming_values.is_empty() && existing_values.is_empty() {
699 return;
700 }
701
702 let mut combined = existing_values;
703 for item in incoming_values {
704 if !combined.contains(&item) {
705 combined.push(item);
706 }
707 }
708
709 if combined.is_empty() {
710 return;
711 }
712
713 if let Some(obj) = existing.as_object_mut() {
714 obj.remove("const");
715 obj.insert("enum".to_string(), Value::Array(combined.clone()));
716 if !obj.contains_key("type")
717 && let Some(inferred) = infer_type_from_enum(&combined)
718 {
719 obj.insert("type".to_string(), Value::String(inferred));
720 }
721 }
722 }
723 }
724}
725
726fn merge_union_schemas(root: &Value, variants: &[Value]) -> Value {
727 let mut merged_props = serde_json::Map::new();
728 let mut required_intersection: Option<std::collections::BTreeSet<String>> = None;
729 let mut enum_values: Vec<Value> = Vec::new();
730 let mut type_candidates: Vec<String> = Vec::new();
731
732 for variant in variants {
733 let sanitized = sanitize_for_gemini(root, variant);
734
735 if let Some(schema_type) = sanitized.get("type").and_then(|v| v.as_str()) {
736 type_candidates.push(schema_type.to_string());
737 }
738
739 if let Some(props) = sanitized.get("properties").and_then(|v| v.as_object()) {
740 for (key, value) in props {
741 merge_property(&mut merged_props, key, value);
742 }
743 }
744
745 if let Some(req) = sanitized.get("required").and_then(|v| v.as_array()) {
746 let req_set: std::collections::BTreeSet<String> = req
747 .iter()
748 .filter_map(|item| item.as_str().map(|s| s.to_string()))
749 .collect();
750
751 required_intersection = match required_intersection.take() {
752 None => Some(req_set),
753 Some(existing) => Some(
754 existing
755 .intersection(&req_set)
756 .cloned()
757 .collect::<std::collections::BTreeSet<String>>(),
758 ),
759 };
760 }
761
762 if let Some(values) = sanitized.get("enum").and_then(|v| v.as_array()) {
763 for value in values {
764 if value.is_null() {
765 continue;
766 }
767 if !enum_values.contains(value) {
768 enum_values.push(value.clone());
769 }
770 }
771 }
772 }
773
774 let schema_type = if !merged_props.is_empty() {
775 "object".to_string()
776 } else if let Some(inferred) = infer_type_from_enum(&enum_values) {
777 inferred
778 } else if let Some(first) = type_candidates.first() {
779 first.clone()
780 } else {
781 "string".to_string()
782 };
783
784 let mut merged = serde_json::Map::new();
785 merged.insert("type".to_string(), Value::String(schema_type));
786
787 if !merged_props.is_empty() {
788 merged.insert("properties".to_string(), Value::Object(merged_props));
789 }
790
791 if let Some(required_set) = required_intersection
792 && !required_set.is_empty()
793 {
794 merged.insert(
795 "required".to_string(),
796 Value::Array(
797 required_set
798 .into_iter()
799 .map(Value::String)
800 .collect::<Vec<_>>(),
801 ),
802 );
803 }
804
805 if !enum_values.is_empty() {
806 merged.insert("enum".to_string(), Value::Array(enum_values));
807 }
808
809 Value::Object(merged)
810}
811
812fn sanitize_for_gemini(root: &Value, schema: &Value) -> Value {
813 if let Some(resolved) = resolve_ref(root, schema) {
814 return sanitize_for_gemini(root, resolved);
815 }
816
817 let Some(obj) = schema.as_object() else {
818 return schema.clone();
819 };
820
821 if let Some(union) = obj
822 .get("oneOf")
823 .or_else(|| obj.get("anyOf"))
824 .or_else(|| obj.get("allOf"))
825 .and_then(|v| v.as_array())
826 {
827 return merge_union_schemas(root, union);
828 }
829
830 let mut out = serde_json::Map::new();
831 for (key, value) in obj {
832 match key.as_str() {
833 "$ref"
834 | "$defs"
835 | "oneOf"
836 | "anyOf"
837 | "allOf"
838 | "const"
839 | "additionalProperties"
840 | "default"
841 | "examples"
842 | "title"
843 | "pattern"
844 | "minLength"
845 | "maxLength"
846 | "minimum"
847 | "maximum"
848 | "minItems"
849 | "maxItems"
850 | "uniqueItems"
851 | "deprecated" => {}
852 "type" => {
853 out.insert("type".to_string(), normalize_type(value));
854 }
855 "properties" => {
856 if let Some(props) = value.as_object() {
857 let mut sanitized_props = serde_json::Map::new();
858 for (prop_key, prop_value) in props {
859 sanitized_props
860 .insert(prop_key.clone(), sanitize_for_gemini(root, prop_value));
861 }
862 out.insert("properties".to_string(), Value::Object(sanitized_props));
863 }
864 }
865 "items" => {
866 out.insert("items".to_string(), sanitize_for_gemini(root, value));
867 }
868 "enum" => {
869 let values = value
870 .as_array()
871 .map(|items| {
872 items
873 .iter()
874 .filter(|v| !v.is_null())
875 .cloned()
876 .collect::<Vec<_>>()
877 })
878 .unwrap_or_default();
879 out.insert("enum".to_string(), Value::Array(values));
880 }
881 _ => {
882 out.insert(key.clone(), sanitize_for_gemini(root, value));
883 }
884 }
885 }
886
887 if let Some(const_value) = obj.get("const")
888 && !const_value.is_null()
889 {
890 out.insert("enum".to_string(), Value::Array(vec![const_value.clone()]));
891 if !out.contains_key("type")
892 && let Some(inferred) = infer_type_from_enum(std::slice::from_ref(const_value))
893 {
894 out.insert("type".to_string(), Value::String(inferred));
895 }
896 }
897
898 if out.get("type") == Some(&Value::String("object".to_string()))
899 && !out.contains_key("properties")
900 {
901 out.insert(
902 "properties".to_string(),
903 Value::Object(serde_json::Map::new()),
904 );
905 }
906
907 if !out.contains_key("type") {
908 if out.contains_key("properties") {
909 out.insert("type".to_string(), Value::String("object".to_string()));
910 } else if let Some(enum_values) = out.get("enum").and_then(|v| v.as_array())
911 && let Some(inferred) = infer_type_from_enum(enum_values)
912 {
913 out.insert("type".to_string(), Value::String(inferred));
914 }
915 }
916
917 Value::Object(out)
918}
919
920fn simplify_property_schema(key: &str, tool_name: &str, property_value: &Value) -> Value {
921 if let Some(prop_map_orig) = property_value.as_object() {
922 let mut simplified_prop = prop_map_orig.clone();
923
924 if simplified_prop.remove("additionalProperties").is_some() {
926 debug!(target: "gemini::simplify_property_schema", "Removed 'additionalProperties' from property '{}' in tool '{}'", key, tool_name);
927 }
928
929 if let Some(type_val) = simplified_prop.get_mut("type") {
931 if let Some(type_array) = type_val.as_array() {
932 if let Some(primary_type) = type_array
933 .iter()
934 .find_map(|v| if v.is_null() { None } else { v.as_str() })
935 {
936 *type_val = serde_json::Value::String(primary_type.to_string());
937 } else {
938 warn!(target: "gemini::simplify_property_schema", "Could not determine primary type for property '{}' in tool '{}', defaulting to string.", key, tool_name);
939 *type_val = serde_json::Value::String("string".to_string());
940 }
941 } else if !type_val.is_string() {
942 warn!(target: "gemini::simplify_property_schema", "Unexpected 'type' format for property '{}' in tool '{}': {:?}. Defaulting to string.", key, tool_name, type_val);
943 *type_val = serde_json::Value::String("string".to_string());
944 }
945 }
947
948 if simplified_prop.get("type") == Some(&serde_json::Value::String("integer".to_string()))
950 && let Some(format_val) = simplified_prop.get_mut("format")
951 && format_val.as_str() == Some("uint64")
952 {
953 *format_val = serde_json::Value::String("int64".to_string());
954 }
957
958 if simplified_prop.get("type") == Some(&serde_json::Value::String("string".to_string())) {
960 let should_remove_format = simplified_prop
961 .get("format")
962 .and_then(|f| f.as_str())
963 .is_some_and(|format_str| format_str != "enum" && format_str != "date-time");
964
965 if should_remove_format
966 && let Some(format_val) = simplified_prop.remove("format")
967 && let Some(format_str) = format_val.as_str()
968 {
969 debug!(target: "gemini::simplify_property_schema", "Removed unsupported format '{}' from string property '{}' in tool '{}'", format_str, key, tool_name);
970 }
971
972 if simplified_prop.remove("minLength").is_some() {
974 debug!(target: "gemini::simplify_property_schema", "Removed 'minLength' from string property '{}' in tool '{}'", key, tool_name);
975 }
976 if simplified_prop.remove("maxLength").is_some() {
977 debug!(target: "gemini::simplify_property_schema", "Removed 'maxLength' from string property '{}' in tool '{}'", key, tool_name);
978 }
979 if simplified_prop.remove("pattern").is_some() {
980 debug!(target: "gemini::simplify_property_schema", "Removed 'pattern' from string property '{}' in tool '{}'", key, tool_name);
981 }
982 }
983
984 if simplified_prop.get("type") == Some(&serde_json::Value::String("array".to_string()))
986 && let Some(items_val) = simplified_prop.get_mut("items")
987 {
988 *items_val = simplify_property_schema(&format!("{key}.items"), tool_name, items_val);
989 }
990
991 if simplified_prop.get("type") == Some(&serde_json::Value::String("object".to_string()))
993 && let Some(Value::Object(props)) = simplified_prop.get_mut("properties")
994 {
995 let simplified_nested_props: serde_json::Map<String, Value> = props
996 .iter()
997 .map(|(nested_key, nested_value)| {
998 (
999 nested_key.clone(),
1000 simplify_property_schema(
1001 &format!("{key}.{nested_key}"),
1002 tool_name,
1003 nested_value,
1004 ),
1005 )
1006 })
1007 .collect();
1008 *props = simplified_nested_props;
1009 }
1010
1011 serde_json::Value::Object(simplified_prop)
1012 } else {
1013 warn!(target: "gemini::simplify_property_schema", "Property value for '{}' in tool '{}' is not an object: {:?}. Using original value.", key, tool_name, property_value);
1014 property_value.clone() }
1016}
1017
1018fn convert_tools(tools: Vec<ToolSchema>) -> Vec<GeminiTool> {
1019 let function_declarations = tools
1020 .into_iter()
1021 .map(|tool| {
1022 let root_schema = tool.input_schema.as_value();
1023 let summary = tool.input_schema.summary();
1024 let schema_type = if summary.schema_type.is_empty() {
1025 "object".to_string()
1026 } else {
1027 summary.schema_type
1028 };
1029
1030 let simplified_properties = summary
1032 .properties
1033 .iter()
1034 .map(|(key, value)| {
1035 let sanitized = sanitize_for_gemini(root_schema, value);
1036 (
1037 key.clone(),
1038 simplify_property_schema(key, &tool.name, &sanitized),
1039 )
1040 })
1041 .collect();
1042
1043 let parameters = GeminiParameterSchema {
1045 schema_type, properties: simplified_properties, required: summary.required, };
1049
1050 GeminiFunctionDeclaration {
1051 name: tool.name,
1052 description: tool.description,
1053 parameters,
1054 }
1055 })
1056 .collect();
1057
1058 vec![GeminiTool {
1059 function_declarations,
1060 }]
1061}
1062
1063fn convert_response(response: GeminiResponse) -> Result<CompletionResponse, ApiError> {
1064 if let Some(feedback) = &response.prompt_feedback
1066 && let Some(reason) = &feedback.block_reason
1067 {
1068 let details = format!(
1069 "Prompt blocked due to {:?}. Safety ratings: {:?}",
1070 reason, feedback.safety_ratings
1071 );
1072 warn!(target: "gemini::convert_response", "{}", details);
1073 return Err(ApiError::RequestBlocked {
1075 provider: "google".to_string(), details,
1077 });
1078 }
1079
1080 let candidates = if let Some(cands) = response.candidates {
1082 if cands.is_empty() {
1083 warn!(target: "gemini::convert_response", "No candidates received, and prompt was not blocked.");
1086 return Err(ApiError::NoChoices {
1088 provider: "google".to_string(),
1089 });
1090 }
1091 cands } else {
1093 warn!(target: "gemini::convert_response", "No candidates field in Gemini response.");
1094 return Err(ApiError::NoChoices {
1096 provider: "google".to_string(),
1097 });
1098 };
1099
1100 let candidate = &candidates[0];
1103
1104 if let Some(reason) = &candidate.finish_reason {
1106 match reason {
1107 GeminiFinishReason::Stop => { }
1108 GeminiFinishReason::MaxTokens => {
1109 warn!(target: "gemini::convert_response", "Response stopped due to MaxTokens limit.");
1110 }
1111 GeminiFinishReason::Safety => {
1112 warn!(target: "gemini::convert_response", "Response stopped due to safety settings. Ratings: {:?}", candidate.safety_ratings);
1113 }
1115 GeminiFinishReason::Recitation => {
1116 warn!(target: "gemini::convert_response", "Response stopped due to potential recitation. Citations: {:?}", candidate.citation_metadata);
1117 }
1118 GeminiFinishReason::MalformedFunctionCall => {
1119 warn!(target: "gemini::convert_response", "Response stopped due to malformed function call.");
1120 }
1121 _ => {
1122 info!(target: "gemini::convert_response", "Response finished with reason: {:?}", reason);
1123 }
1124 }
1125 }
1126
1127 if let Some(usage) = &response.usage_metadata {
1129 debug!(target: "gemini::convert_response", "Usage - Prompt Tokens: {:?}, Candidates Tokens: {:?}, Total Tokens: {:?}",
1130 usage.prompt, usage.candidates, usage.total);
1131 }
1132 let usage = response.usage_metadata.as_ref().and_then(map_gemini_usage);
1133
1134 let content: Vec<AssistantContent> = candidate
1135 .content .parts .iter()
1138 .filter_map(|part| { if part.thought {
1141 debug!(target: "gemini::convert_response", "Received thought part: {:?}", part);
1142 if let GeminiResponsePartData::Text { text } = &part.data {
1144 Some(AssistantContent::Thought {
1145 thought: ThoughtContent::Simple {
1146 text: text.clone(),
1147 },
1148 })
1149 } else {
1150 warn!(target: "gemini::convert_response", "Thought part contains non-text data: {:?}", part.data);
1151 None
1152 }
1153 } else {
1154 match &part.data {
1156 GeminiResponsePartData::Text { text } => Some(AssistantContent::Text {
1157 text: text.clone(),
1158 }),
1159 GeminiResponsePartData::InlineData { inline_data } => {
1160 warn!(target: "gemini::convert_response", "Received InlineData part (MIME type: {}). Converting to placeholder text.", inline_data.mime_type);
1161 Some(AssistantContent::Text { text: format!("[Inline Data: {}]", inline_data.mime_type) })
1162 }
1163 GeminiResponsePartData::FunctionCall { function_call } => {
1164 Some(AssistantContent::ToolCall {
1165 tool_call: steer_tools::ToolCall {
1166 id: uuid::Uuid::new_v4().to_string(), name: function_call.name.clone(),
1168 parameters: function_call.args.clone(),
1169 },
1170 thought_signature: part
1171 .thought_signature
1172 .clone()
1173 .map(ThoughtSignature::new),
1174 })
1175 }
1176 GeminiResponsePartData::FileData { file_data } => {
1177 warn!(target: "gemini::convert_response", "Received FileData part (URI: {}). Converting to placeholder text.", file_data.file_uri);
1178 Some(AssistantContent::Text { text: format!("[File Data: {}]", file_data.file_uri) })
1179 }
1180 GeminiResponsePartData::ExecutableCode { executable_code } => {
1181 info!(target: "gemini::convert_response", "Received ExecutableCode part ({}). Converting to text.",
1182 executable_code.language);
1183 Some(AssistantContent::Text {
1184 text: format!(
1185 "```{}
1186{}
1187```",
1188 executable_code.language.to_lowercase(),
1189 executable_code.code
1190 ),
1191 })
1192 }
1193 }
1194 }
1195 })
1196 .collect();
1197
1198 Ok(CompletionResponse { content, usage })
1199}
1200
1201#[async_trait]
1202impl Provider for GeminiClient {
1203 fn name(&self) -> &'static str {
1204 "google"
1205 }
1206
1207 async fn complete(
1208 &self,
1209 model_id: &ModelId,
1210 messages: Vec<AppMessage>,
1211 system: Option<SystemContext>,
1212 tools: Option<Vec<ToolSchema>>,
1213 _call_options: Option<ModelParameters>,
1214 token: CancellationToken,
1215 ) -> Result<CompletionResponse, ApiError> {
1216 let model_name = &model_id.id; let url = format!(
1218 "{}/models/{}:generateContent?key={}",
1219 GEMINI_API_BASE, model_name, self.api_key
1220 );
1221
1222 let gemini_contents = convert_messages(messages)?;
1223
1224 let system_instruction = system
1225 .and_then(|context| context.render())
1226 .map(|instructions| GeminiSystemInstruction {
1227 parts: vec![GeminiRequestPart::Text { text: instructions }],
1228 });
1229
1230 let gemini_tools = tools.map(convert_tools);
1231
1232 let (temperature, top_p, max_output_tokens) = {
1234 let opts = _call_options.as_ref();
1235 (
1236 opts.and_then(|o| o.temperature).or(Some(1.0)),
1237 opts.and_then(|o| o.top_p).or(Some(0.95)),
1238 opts.and_then(|o| o.max_tokens)
1239 .map(|v| v as i32)
1240 .or(Some(65536)),
1241 )
1242 };
1243 let thinking_config = _call_options
1244 .as_ref()
1245 .and_then(|o| o.thinking_config)
1246 .and_then(|tc| {
1247 if tc.enabled {
1248 Some(GeminiThinkingConfig {
1249 include_thoughts: tc.include_thoughts,
1250 thinking_budget: tc.budget_tokens.map(|v| v as i32),
1251 })
1252 } else {
1253 None
1254 }
1255 });
1256
1257 let request = GeminiRequest {
1258 contents: gemini_contents,
1259 system_instruction,
1260 tools: gemini_tools,
1261 generation_config: Some(GeminiGenerationConfig {
1262 max_output_tokens,
1263 temperature,
1264 top_p,
1265 thinking_config,
1266 ..Default::default()
1267 }),
1268 };
1269
1270 let response = tokio::select! {
1271 biased;
1272 () = token.cancelled() => {
1273 debug!(target: "gemini::complete", "Cancellation token triggered before sending request.");
1274 return Err(ApiError::Cancelled{ provider: self.name().to_string()});
1275 }
1276 res = self.client.post(&url).json(&request).send() => {
1277 res.map_err(ApiError::Network)?
1278 }
1279 };
1280 let status = response.status();
1281
1282 if status != StatusCode::OK {
1283 let error_text = response.text().await.map_err(ApiError::Network)?;
1284 error!(target: "Gemini API Error Response", "Status: {}, Body: {}", status, error_text);
1285 return Err(match status.as_u16() {
1286 401 | 403 => ApiError::AuthenticationFailed {
1287 provider: self.name().to_string(),
1288 details: error_text,
1289 },
1290 429 => ApiError::RateLimited {
1291 provider: self.name().to_string(),
1292 details: error_text,
1293 },
1294 400 | 404 => {
1295 error!(target: "Gemini API Error Response", "Status: {}, Body: {}, Request: {}", status, error_text, serde_json::to_string_pretty(&request).unwrap_or_else(|_| "Failed to serialize request".to_string()));
1296 ApiError::InvalidRequest {
1297 provider: self.name().to_string(),
1298 details: error_text,
1299 }
1300 } 500..=599 => ApiError::ServerError {
1302 provider: self.name().to_string(),
1303 status_code: status.as_u16(),
1304 details: error_text,
1305 },
1306 _ => ApiError::Unknown {
1307 provider: self.name().to_string(),
1308 details: error_text,
1309 },
1310 });
1311 }
1312
1313 let response_text = response.text().await.map_err(ApiError::Network)?;
1314
1315 match serde_json::from_str::<GeminiResponse>(&response_text) {
1316 Ok(gemini_response) => {
1317 convert_response(gemini_response).map_err(|e| ApiError::ResponseParsingError {
1318 provider: self.name().to_string(),
1319 details: e.to_string(),
1320 })
1321 }
1322 Err(e) => {
1323 error!(target: "Gemini API JSON Parsing Error", "Failed to parse JSON: {}. Response body:\n{}", e, response_text);
1324 Err(ApiError::ResponseParsingError {
1325 provider: self.name().to_string(),
1326 details: format!("Status: {status}, Error: {e}, Body: {response_text}"),
1327 })
1328 }
1329 }
1330 }
1331
1332 async fn stream_complete(
1333 &self,
1334 model_id: &ModelId,
1335 messages: Vec<AppMessage>,
1336 system: Option<SystemContext>,
1337 tools: Option<Vec<ToolSchema>>,
1338 _call_options: Option<ModelParameters>,
1339 token: CancellationToken,
1340 ) -> Result<CompletionStream, ApiError> {
1341 let model_name = &model_id.id;
1342 let url = format!(
1343 "{}/models/{}:streamGenerateContent?alt=sse&key={}",
1344 GEMINI_API_BASE, model_name, self.api_key
1345 );
1346
1347 let gemini_contents = convert_messages(messages)?;
1348
1349 let system_instruction = system
1350 .and_then(|context| context.render())
1351 .map(|instructions| GeminiSystemInstruction {
1352 parts: vec![GeminiRequestPart::Text { text: instructions }],
1353 });
1354
1355 let gemini_tools = tools.map(convert_tools);
1356
1357 let (temperature, top_p, max_output_tokens) = {
1358 let opts = _call_options.as_ref();
1359 (
1360 opts.and_then(|o| o.temperature).or(Some(1.0)),
1361 opts.and_then(|o| o.top_p).or(Some(0.95)),
1362 opts.and_then(|o| o.max_tokens)
1363 .map(|v| v as i32)
1364 .or(Some(65536)),
1365 )
1366 };
1367 let thinking_config = _call_options
1368 .as_ref()
1369 .and_then(|o| o.thinking_config)
1370 .and_then(|tc| {
1371 if tc.enabled {
1372 Some(GeminiThinkingConfig {
1373 include_thoughts: tc.include_thoughts,
1374 thinking_budget: tc.budget_tokens.map(|v| v as i32),
1375 })
1376 } else {
1377 None
1378 }
1379 });
1380
1381 let request = GeminiRequest {
1382 contents: gemini_contents,
1383 system_instruction,
1384 tools: gemini_tools,
1385 generation_config: Some(GeminiGenerationConfig {
1386 max_output_tokens,
1387 temperature,
1388 top_p,
1389 thinking_config,
1390 ..Default::default()
1391 }),
1392 };
1393
1394 let response = tokio::select! {
1395 biased;
1396 () = token.cancelled() => {
1397 return Err(ApiError::Cancelled{ provider: self.name().to_string()});
1398 }
1399 res = self.client.post(&url).json(&request).send() => {
1400 res.map_err(ApiError::Network)?
1401 }
1402 };
1403
1404 let status = response.status();
1405 if status != StatusCode::OK {
1406 let error_text = response.text().await.map_err(ApiError::Network)?;
1407 error!(target: "gemini::stream", "API error - Status: {}, Body: {}", status, error_text);
1408 return Err(match status.as_u16() {
1409 401 | 403 => ApiError::AuthenticationFailed {
1410 provider: self.name().to_string(),
1411 details: error_text,
1412 },
1413 429 => ApiError::RateLimited {
1414 provider: self.name().to_string(),
1415 details: error_text,
1416 },
1417 400 | 404 => ApiError::InvalidRequest {
1418 provider: self.name().to_string(),
1419 details: error_text,
1420 },
1421 500..=599 => ApiError::ServerError {
1422 provider: self.name().to_string(),
1423 status_code: status.as_u16(),
1424 details: error_text,
1425 },
1426 _ => ApiError::Unknown {
1427 provider: self.name().to_string(),
1428 details: error_text,
1429 },
1430 });
1431 }
1432
1433 let byte_stream = response.bytes_stream();
1434 let sse_stream = parse_sse_stream(byte_stream);
1435
1436 Ok(Box::pin(Self::convert_gemini_stream(sse_stream, token)))
1437 }
1438}
1439
1440impl GeminiClient {
1441 fn convert_gemini_stream(
1442 mut sse_stream: impl futures::Stream<Item = Result<crate::api::sse::SseEvent, SseParseError>>
1443 + Unpin
1444 + Send
1445 + 'static,
1446 token: CancellationToken,
1447 ) -> impl futures::Stream<Item = StreamChunk> + Send + 'static {
1448 async_stream::stream! {
1449 let mut content: Vec<AssistantContent> = Vec::new();
1450 let mut latest_usage: Option<TokenUsage> = None;
1451 loop {
1452 if token.is_cancelled() {
1453 yield StreamChunk::Error(StreamError::Cancelled);
1454 break;
1455 }
1456
1457 let event_result = tokio::select! {
1458 biased;
1459 () = token.cancelled() => {
1460 yield StreamChunk::Error(StreamError::Cancelled);
1461 break;
1462 }
1463 event = sse_stream.next() => event
1464 };
1465
1466 let Some(event_result) = event_result else {
1467 let content = std::mem::take(&mut content);
1468 yield StreamChunk::MessageComplete(CompletionResponse {
1469 content,
1470 usage: latest_usage,
1471 });
1472 break;
1473 };
1474
1475 let event = match event_result {
1476 Ok(e) => e,
1477 Err(e) => {
1478 yield StreamChunk::Error(StreamError::SseParse(e));
1479 break;
1480 }
1481 };
1482
1483 let chunk: GeminiResponse = match serde_json::from_str(&event.data) {
1484 Ok(c) => c,
1485 Err(e) => {
1486 debug!(target: "gemini::stream", "Failed to parse chunk: {} data: {}", e, event.data);
1487 continue;
1488 }
1489 };
1490
1491 if let Some(usage) = chunk.usage_metadata.as_ref()
1492 && let Some(mapped) = map_gemini_usage(usage) {
1493 latest_usage = Some(mapped);
1494 }
1495
1496 if let Some(candidates) = chunk.candidates {
1497 for candidate in candidates {
1498 for part in candidate.content.parts {
1499 let GeminiResponsePart {
1500 thought,
1501 thought_signature,
1502 data,
1503 } = part;
1504 let thought_signature =
1505 thought_signature.map(ThoughtSignature::new);
1506
1507 if thought {
1508 if let GeminiResponsePartData::Text { text } = data {
1509 match content.last_mut() {
1510 Some(AssistantContent::Thought {
1511 thought: ThoughtContent::Simple { text: buf },
1512 }) => buf.push_str(&text),
1513 _ => {
1514 content.push(AssistantContent::Thought {
1515 thought: ThoughtContent::Simple { text: text.clone() },
1516 });
1517 }
1518 }
1519 yield StreamChunk::ThinkingDelta(text);
1520 }
1521 } else {
1522 match data {
1523 GeminiResponsePartData::Text { text } => {
1524 match content.last_mut() {
1525 Some(AssistantContent::Text { text: buf }) => buf.push_str(&text),
1526 _ => {
1527 content.push(AssistantContent::Text { text: text.clone() });
1528 }
1529 }
1530 yield StreamChunk::TextDelta(text);
1531 }
1532 GeminiResponsePartData::FunctionCall { function_call } => {
1533 let id = uuid::Uuid::new_v4().to_string();
1534 content.push(AssistantContent::ToolCall {
1535 tool_call: steer_tools::ToolCall {
1536 id: id.clone(),
1537 name: function_call.name.clone(),
1538 parameters: function_call.args.clone(),
1539 },
1540 thought_signature,
1541 });
1542 yield StreamChunk::ToolUseStart {
1543 id: id.clone(),
1544 name: function_call.name,
1545 };
1546 yield StreamChunk::ToolUseInputDelta {
1547 id,
1548 delta: function_call.args.to_string(),
1549 };
1550 }
1551 _ => {}
1552 }
1553 }
1554 }
1555 }
1556 }
1557 }
1558 }
1559 }
1560}
1561
// Unit tests for the Gemini provider. Most tests drive the conversion helpers
// (`simplify_property_schema`, `convert_tools`, `convert_messages`,
// `map_gemini_usage`, `convert_response`, `GeminiClient::convert_gemini_stream`)
// with hand-built fixtures; `test_stream_complete_real_api` calls the live API and
// is `#[ignore]`d unless run explicitly with GOOGLE_API_KEY set.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::app::conversation::{ImageContent, ImageSource, Message, MessageData, UserContent};
    use serde_json::json;

    // `additionalProperties` is removed from an object schema; `type` and
    // `properties` are kept.
    #[test]
    fn test_simplify_property_schema_removes_additional_properties() {
        let property_value = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            },
            "additionalProperties": false
        });

        let expected = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });

        let result = simplify_property_schema("testProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    // A string schema with format "uri" plus minLength/maxLength/pattern is reduced
    // to just `{"type": "string"}`.
    #[test]
    fn test_simplify_property_schema_removes_unsupported_string_formats() {
        let property_value = json!({
            "type": "string",
            "format": "uri",
            "minLength": 1,
            "maxLength": 100,
            "pattern": "^https://"
        });

        let expected = json!({
            "type": "string"
        });

        let result = simplify_property_schema("urlProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    // format "date-time" passes through unchanged.
    #[test]
    fn test_simplify_property_schema_keeps_supported_string_formats() {
        let property_value = json!({
            "type": "string",
            "format": "date-time"
        });

        let expected = json!({
            "type": "string",
            "format": "date-time"
        });

        let result = simplify_property_schema("dateProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    // A nullable type array `["string", "null"]` collapses to "string", and the
    // "email" format is dropped.
    #[test]
    fn test_simplify_property_schema_handles_array_types() {
        let property_value = json!({
            "type": ["string", "null"],
            "format": "email"
        });

        let expected = json!({
            "type": "string"
        });

        let result = simplify_property_schema("emailProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    // Simplification recurses into an array's `items` schema.
    #[test]
    fn test_simplify_property_schema_recursively_handles_array_items() {
        let property_value = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "url": {
                        "type": "string",
                        "format": "uri"
                    }
                },
                "additionalProperties": false
            }
        });

        let expected = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "url": {
                        "type": "string"
                    }
                }
            }
        });

        let result = simplify_property_schema("linksProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    // Simplification recurses into nested object `properties`, removing
    // `additionalProperties` at every level (whether true or false).
    #[test]
    fn test_simplify_property_schema_recursively_handles_nested_objects() {
        let property_value = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "type": "object",
                    "properties": {
                        "field": {
                            "type": "string",
                            "format": "hostname"
                        }
                    },
                    "additionalProperties": true
                }
            },
            "additionalProperties": false
        });

        let expected = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "type": "object",
                    "properties": {
                        "field": {
                            "type": "string"
                        }
                    }
                }
            }
        });

        let result = simplify_property_schema("complexProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    // Integer format "uint64" is rewritten to "int64".
    #[test]
    fn test_simplify_property_schema_fixes_uint64_format() {
        let property_value = json!({
            "type": "integer",
            "format": "uint64"
        });

        let expected = json!({
            "type": "integer",
            "format": "int64"
        });

        let result = simplify_property_schema("idProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    // End-to-end `convert_tools`: one `ToolSchema` becomes one `GeminiTool` with a
    // single function declaration whose property schemas have been simplified
    // (minLength, format "uri" and additionalProperties removed).
    #[test]
    fn test_convert_tools_integration() {
        use steer_tools::{InputSchema, ToolSchema};

        let tool = ToolSchema {
            name: "create_issue".to_string(),
            display_name: "Create Issue".to_string(),
            description: "Create an issue".to_string(),
            input_schema: InputSchema::object(
                {
                    let mut props = serde_json::Map::new();
                    props.insert(
                        "title".to_string(),
                        json!({
                            "type": "string",
                            "minLength": 1
                        }),
                    );
                    props.insert(
                        "links".to_string(),
                        json!({
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "url": {
                                        "type": "string",
                                        "format": "uri"
                                    }
                                },
                                "additionalProperties": false
                            }
                        }),
                    );
                    props
                },
                vec!["title".to_string()],
            ),
        };

        let expected_tools = vec![GeminiTool {
            function_declarations: vec![GeminiFunctionDeclaration {
                name: "create_issue".to_string(),
                description: "Create an issue".to_string(),
                parameters: GeminiParameterSchema {
                    schema_type: "object".to_string(),
                    properties: {
                        let mut props = serde_json::Map::new();
                        props.insert(
                            "title".to_string(),
                            json!({
                                "type": "string"
                            }),
                        );
                        props.insert(
                            "links".to_string(),
                            json!({
                                "type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "url": {
                                            "type": "string"
                                        }
                                    }
                                }
                            }),
                        );
                        props
                    },
                    required: vec!["title".to_string()],
                },
            }],
        }];

        let result = convert_tools(vec![tool]);
        assert_eq!(result, expected_tools);
    }

    // A data-URL image becomes an `inlineData` part: the MIME type is kept and only
    // the base64 payload after the comma is sent as `data`.
    #[test]
    fn test_convert_messages_includes_data_url_image_as_inline_data() {
        let messages = vec![Message {
            data: MessageData::User {
                content: vec![UserContent::Image {
                    image: ImageContent {
                        mime_type: "image/png".to_string(),
                        source: ImageSource::DataUrl {
                            data_url: "data:image/png;base64,Zm9v".to_string(),
                        },
                        width: None,
                        height: None,
                        bytes: None,
                        sha256: None,
                    },
                }],
            },
            timestamp: 1,
            id: "msg-image".to_string(),
            parent_message_id: None,
        }];

        let converted = convert_messages(messages).expect("convert messages");
        assert_eq!(converted.len(), 1);
        let first = &converted[0];
        assert_eq!(first.role, "user");
        assert_eq!(first.parts.len(), 1);

        let json = serde_json::to_value(&first.parts[0]).expect("serialize part");
        assert_eq!(json["inlineData"]["mimeType"], "image/png");
        assert_eq!(json["inlineData"]["data"], "Zm9v");
    }

    // A session-file image source is rejected with `UnsupportedFeature` rather than
    // being sent to the API.
    #[test]
    fn test_convert_messages_rejects_session_file_image_source() {
        let messages = vec![Message {
            data: MessageData::User {
                content: vec![UserContent::Image {
                    image: ImageContent {
                        mime_type: "image/png".to_string(),
                        source: ImageSource::SessionFile {
                            relative_path: "session-1/image.png".to_string(),
                        },
                        width: None,
                        height: None,
                        bytes: None,
                        sha256: None,
                    },
                }],
            },
            timestamp: 1,
            id: "msg-image".to_string(),
            parent_message_id: None,
        }];

        let err = convert_messages(messages).expect_err("expected unsupported feature");
        match err {
            ApiError::UnsupportedFeature {
                provider,
                feature,
                details,
            } => {
                assert_eq!(provider, "google");
                assert_eq!(feature, "image input source");
                assert!(details.contains("session file"));
            }
            other => panic!("Expected UnsupportedFeature, got {other:?}"),
        }
    }

    // With the prompt count missing, it is derived as total - candidates (19 - 7 = 12).
    #[test]
    fn test_map_gemini_usage_derives_missing_counter() {
        let usage = GeminiUsageMetadata {
            prompt: None,
            candidates: Some(7),
            total: Some(19),
        };

        assert_eq!(map_gemini_usage(&usage), Some(TokenUsage::new(12, 7, 19)));
    }

    // With only one counter present nothing can be derived, so no usage is reported.
    #[test]
    fn test_map_gemini_usage_returns_none_when_counters_not_derivable() {
        let usage = GeminiUsageMetadata {
            prompt: Some(9),
            candidates: None,
            total: None,
        };

        assert_eq!(map_gemini_usage(&usage), None);
    }

    // `convert_response` carries usage metadata through and maps a plain text part
    // to `AssistantContent::Text`.
    #[test]
    fn test_convert_response_maps_usage_metadata() {
        let response = GeminiResponse {
            candidates: Some(vec![GeminiCandidate {
                content: GeminiContentResponse {
                    role: "model".to_string(),
                    parts: vec![GeminiResponsePart {
                        thought: false,
                        thought_signature: None,
                        data: GeminiResponsePartData::Text {
                            text: "Hello".to_string(),
                        },
                    }],
                },
                finish_reason: Some(GeminiFinishReason::Stop),
                safety_ratings: None,
                citation_metadata: None,
            }]),
            prompt_feedback: None,
            usage_metadata: Some(GeminiUsageMetadata {
                prompt: Some(11),
                candidates: Some(5),
                total: Some(16),
            }),
        };

        let converted = convert_response(response).expect("response should convert");

        assert_eq!(converted.usage, Some(TokenUsage::new(11, 5, 16)));
        assert!(matches!(
            converted.content.first(),
            Some(AssistantContent::Text { text }) if text == "Hello"
        ));
    }

    // Each text chunk is forwarded as its own TextDelta, followed by MessageComplete
    // once the SSE source is exhausted.
    #[tokio::test]
    async fn test_convert_gemini_stream_text_deltas() {
        use crate::api::provider::StreamChunk;
        use crate::api::sse::SseEvent;
        use futures::StreamExt;
        use futures::stream;
        use std::pin::pin;
        use tokio_util::sync::CancellationToken;

        let events = vec![
            Ok(SseEvent {
                event_type: None,
                data: r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#
                    .to_string(),
                id: None,
            }),
            Ok(SseEvent {
                event_type: None,
                data:
                    r#"{"candidates":[{"content":{"role":"model","parts":[{"text":" world"}]}}]}"#
                        .to_string(),
                id: None,
            }),
        ];

        let sse_stream = stream::iter(events);
        let token = CancellationToken::new();
        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));

        let first_delta = stream.next().await.unwrap();
        assert!(matches!(first_delta, StreamChunk::TextDelta(ref t) if t == "Hello"));

        let second_delta = stream.next().await.unwrap();
        assert!(matches!(second_delta, StreamChunk::TextDelta(ref t) if t == " world"));

        let complete = stream.next().await.unwrap();
        assert!(matches!(complete, StreamChunk::MessageComplete(_)));
    }

    // A trailing usage-only chunk (no candidates) is reflected in the final
    // MessageComplete's usage, alongside the previously streamed text.
    #[tokio::test]
    async fn test_convert_gemini_stream_captures_final_usage() {
        use crate::api::provider::StreamChunk;
        use crate::api::sse::SseEvent;
        use futures::StreamExt;
        use futures::stream;
        use std::pin::pin;
        use tokio_util::sync::CancellationToken;

        let events = vec![
            Ok(SseEvent {
                event_type: None,
                data: r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#
                    .to_string(),
                id: None,
            }),
            Ok(SseEvent {
                event_type: None,
                data: r#"{"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":4,"totalTokenCount":14}}"#
                    .to_string(),
                id: None,
            }),
        ];

        let sse_stream = stream::iter(events);
        let token = CancellationToken::new();
        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));

        let first_delta = stream.next().await.unwrap();
        assert!(matches!(first_delta, StreamChunk::TextDelta(ref t) if t == "Hello"));

        let complete = stream.next().await.unwrap();
        if let StreamChunk::MessageComplete(response) = complete {
            assert_eq!(response.usage, Some(TokenUsage::new(10, 4, 14)));
            assert!(matches!(
                response.content.first(),
                Some(AssistantContent::Text { text }) if text == "Hello"
            ));
        } else {
            panic!("Expected MessageComplete");
        }
    }

    // A `"thought": true` part is emitted as ThinkingDelta and accumulated as a
    // Thought item, separate from the following Text item.
    #[tokio::test]
    async fn test_convert_gemini_stream_with_thinking() {
        use crate::api::provider::StreamChunk;
        use crate::api::sse::SseEvent;
        use futures::StreamExt;
        use futures::stream;
        use std::pin::pin;
        use tokio_util::sync::CancellationToken;

        let events = vec![
            Ok(SseEvent {
                event_type: None,
                data: r#"{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"Let me think..."}]}}]}"#.to_string(),
                id: None,
            }),
            Ok(SseEvent {
                event_type: None,
                data: r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"The answer"}]}}]}"#.to_string(),
                id: None,
            }),
        ];

        let sse_stream = stream::iter(events);
        let token = CancellationToken::new();
        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));

        let thinking_delta = stream.next().await.unwrap();
        assert!(
            matches!(thinking_delta, StreamChunk::ThinkingDelta(ref t) if t == "Let me think...")
        );

        let text_delta = stream.next().await.unwrap();
        assert!(matches!(text_delta, StreamChunk::TextDelta(ref t) if t == "The answer"));

        let complete = stream.next().await.unwrap();
        if let StreamChunk::MessageComplete(response) = complete {
            assert_eq!(response.content.len(), 2);
            assert!(matches!(
                &response.content[0],
                AssistantContent::Thought { .. }
            ));
            assert!(matches!(
                &response.content[1],
                AssistantContent::Text { .. }
            ));
        } else {
            panic!("Expected MessageComplete");
        }
    }

    // A functionCall part yields ToolUseStart then ToolUseInputDelta, and the final
    // ToolCall keeps the part's thoughtSignature.
    #[tokio::test]
    async fn test_convert_gemini_stream_with_function_call() {
        use crate::api::provider::StreamChunk;
        use crate::api::sse::SseEvent;
        use futures::StreamExt;
        use futures::stream;
        use std::pin::pin;
        use tokio_util::sync::CancellationToken;

        let events = vec![
            Ok(SseEvent {
                event_type: None,
                data: r#"{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"get_weather","args":{"city":"NYC"}},"thoughtSignature":"sig_123"}]}}]}"#.to_string(),
                id: None,
            }),
        ];

        let sse_stream = stream::iter(events);
        let token = CancellationToken::new();
        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));

        let tool_start = stream.next().await.unwrap();
        assert!(
            matches!(tool_start, StreamChunk::ToolUseStart { ref name, .. } if name == "get_weather")
        );

        let tool_input = stream.next().await.unwrap();
        assert!(matches!(tool_input, StreamChunk::ToolUseInputDelta { .. }));

        let complete = stream.next().await.unwrap();
        if let StreamChunk::MessageComplete(response) = complete {
            assert_eq!(response.content.len(), 1);
            if let AssistantContent::ToolCall {
                tool_call,
                thought_signature,
            } = &response.content[0]
            {
                assert_eq!(tool_call.name, "get_weather");
                assert_eq!(
                    thought_signature.as_ref().map(|sig| sig.as_str()),
                    Some("sig_123")
                );
            } else {
                panic!("Expected ToolCall");
            }
        } else {
            panic!("Expected MessageComplete");
        }
    }

    // With the token cancelled before polling, the first item is
    // Error(Cancelled) even though an event is ready upstream.
    #[tokio::test]
    async fn test_convert_gemini_stream_cancellation() {
        use crate::api::error::StreamError;
        use crate::api::provider::StreamChunk;
        use crate::api::sse::SseEvent;
        use futures::StreamExt;
        use futures::stream;
        use std::pin::pin;
        use tokio_util::sync::CancellationToken;

        let events = vec![Ok(SseEvent {
            event_type: None,
            data: r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#
                .to_string(),
            id: None,
        })];

        let sse_stream = stream::iter(events);
        let token = CancellationToken::new();
        token.cancel();

        let mut stream = pin!(GeminiClient::convert_gemini_stream(sse_stream, token));

        let cancelled = stream.next().await.unwrap();
        assert!(matches!(
            cancelled,
            StreamChunk::Error(StreamError::Cancelled)
        ));
    }

    // Live smoke test against the real API. Ignored by default; run with
    // `cargo test -- --ignored` and GOOGLE_API_KEY set (a .env file is also read).
    // NOTE(review): the hard-coded preview model name may no longer be served — confirm.
    #[tokio::test]
    #[ignore = "Requires GOOGLE_API_KEY environment variable"]
    async fn test_stream_complete_real_api() {
        use crate::api::Provider;
        use crate::api::provider::StreamChunk;
        use crate::app::conversation::{Message, MessageData, UserContent};
        use futures::StreamExt;
        use tokio_util::sync::CancellationToken;

        dotenvy::dotenv().ok();
        let api_key = std::env::var("GOOGLE_API_KEY").expect("GOOGLE_API_KEY must be set");
        let client = GeminiClient::new(api_key);

        let message = Message {
            data: MessageData::User {
                content: vec![UserContent::Text {
                    text: "Say exactly: Hello".to_string(),
                }],
            },
            timestamp: chrono::Utc::now().timestamp_millis() as u64,
            id: "test-msg".to_string(),
            parent_message_id: None,
        };

        let model_id = ModelId::new(
            crate::config::provider::google(),
            "gemini-2.5-flash-preview-04-17",
        );
        let token = CancellationToken::new();

        let mut stream = client
            .stream_complete(&model_id, vec![message], None, None, None, token)
            .await
            .expect("stream_complete should succeed");

        let mut got_text_delta = false;
        let mut got_message_complete = false;
        let mut accumulated_text = String::new();

        while let Some(chunk) = stream.next().await {
            match chunk {
                StreamChunk::TextDelta(text) => {
                    got_text_delta = true;
                    accumulated_text.push_str(&text);
                }
                StreamChunk::MessageComplete(response) => {
                    got_message_complete = true;
                    assert!(!response.content.is_empty());
                }
                StreamChunk::Error(e) => panic!("Unexpected error: {e:?}"),
                _ => {}
            }
        }

        assert!(got_text_delta, "Should receive at least one TextDelta");
        assert!(
            got_message_complete,
            "Should receive MessageComplete at the end"
        );
        assert!(
            accumulated_text.to_lowercase().contains("hello"),
            "Response should contain 'hello', got: {accumulated_text}"
        );
    }
}