1use async_trait::async_trait;
2use futures::StreamExt;
3use reqwest::{Client as HttpClient, StatusCode};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use tokio_util::sync::CancellationToken;
7use tracing::{debug, error, info, warn};
8
9use crate::api::error::{ApiError, SseParseError, StreamError};
10use crate::api::provider::{CompletionResponse, CompletionStream, Provider, StreamChunk};
11use crate::api::sse::parse_sse_stream;
12use crate::app::SystemContext;
13use crate::app::conversation::{
14 AssistantContent, Message as AppMessage, ThoughtContent, ThoughtSignature, ToolResult,
15 UserContent,
16};
17use crate::config::model::{ModelId, ModelParameters};
18use steer_tools::ToolSchema;
19
20const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
21
22#[derive(Debug, Deserialize, Serialize, Clone)] struct GeminiBlob {
24 #[serde(rename = "mimeType")]
25 mime_type: String,
26 data: String, }
28
29#[derive(Debug, Deserialize, Serialize, Clone)] struct GeminiFileData {
31 #[serde(rename = "mimeType")]
32 mime_type: String,
33 #[serde(rename = "fileUri")]
34 file_uri: String,
35}
36
37pub struct GeminiClient {
38 api_key: String,
39 client: HttpClient,
40}
41
42impl GeminiClient {
43 pub fn new(api_key: impl Into<String>) -> Self {
44 Self {
45 api_key: api_key.into(),
46 client: HttpClient::new(),
47 }
48 }
49}
50
51#[derive(Debug, Serialize)]
52struct GeminiRequest {
53 contents: Vec<GeminiContent>,
54 #[serde(skip_serializing_if = "Option::is_none")]
55 #[serde(rename = "systemInstruction")]
56 system_instruction: Option<GeminiSystemInstruction>,
57 #[serde(skip_serializing_if = "Option::is_none")]
58 tools: Option<Vec<GeminiTool>>,
59 #[serde(skip_serializing_if = "Option::is_none")]
60 #[serde(rename = "generationConfig")]
61 generation_config: Option<GeminiGenerationConfig>,
62}
63
64#[derive(Debug, Serialize, Default, Clone)]
65struct GeminiGenerationConfig {
66 #[serde(skip_serializing_if = "Option::is_none")]
67 #[serde(rename = "stopSequences")]
68 stop_sequences: Option<Vec<String>>,
69 #[serde(skip_serializing_if = "Option::is_none")]
70 #[serde(rename = "responseMimeType")]
71 response_mime_type: Option<GeminiMimeType>,
72 #[serde(skip_serializing_if = "Option::is_none")]
73 #[serde(rename = "candidateCount")]
74 candidate_count: Option<i32>,
75 #[serde(skip_serializing_if = "Option::is_none")]
76 #[serde(rename = "maxOutputTokens")]
77 max_output_tokens: Option<i32>,
78 #[serde(skip_serializing_if = "Option::is_none")]
79 temperature: Option<f32>,
80 #[serde(skip_serializing_if = "Option::is_none")]
81 #[serde(rename = "topP")]
82 top_p: Option<f32>,
83 #[serde(skip_serializing_if = "Option::is_none")]
84 #[serde(rename = "topK")]
85 top_k: Option<i32>,
86 #[serde(skip_serializing_if = "Option::is_none")]
87 #[serde(rename = "thinkingConfig")]
88 thinking_config: Option<GeminiThinkingConfig>,
89}
90
91#[derive(Debug, Serialize, Default, Clone)]
92struct GeminiThinkingConfig {
93 #[serde(skip_serializing_if = "Option::is_none")]
94 #[serde(rename = "includeThoughts")]
95 include_thoughts: Option<bool>,
96 #[serde(skip_serializing_if = "Option::is_none")]
97 #[serde(rename = "thinkingBudget")]
98 thinking_budget: Option<i32>,
99}
100
101#[expect(dead_code)]
102#[derive(Debug, Serialize, Clone)]
103#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
104enum GeminiMimeType {
105 MimeTypeUnspecified,
106 TextPlain,
107 ApplicationJson,
108}
109
110#[derive(Debug, Serialize)]
111struct GeminiSystemInstruction {
112 parts: Vec<GeminiRequestPart>,
113}
114
115#[derive(Debug, Serialize)]
116struct GeminiContent {
117 role: String,
118 parts: Vec<GeminiRequestPart>,
119}
120
121#[derive(Debug, Serialize)]
123#[serde(untagged)]
124enum GeminiRequestPart {
125 Text {
126 text: String,
127 },
128 #[serde(rename = "functionCall")]
129 FunctionCall {
130 #[serde(rename = "functionCall")]
131 function_call: GeminiFunctionCall, #[serde(rename = "thoughtSignature", skip_serializing_if = "Option::is_none")]
133 thought_signature: Option<String>,
134 },
135 #[serde(rename = "functionResponse")]
136 FunctionResponse {
137 #[serde(rename = "functionResponse")]
138 function_response: GeminiFunctionResponse, },
140}
141
142#[derive(Debug, Deserialize)]
144#[serde(untagged)]
145enum GeminiResponsePartData {
146 Text {
147 text: String,
148 },
149 #[serde(rename = "inlineData")]
150 InlineData {
151 #[serde(rename = "inlineData")]
152 inline_data: GeminiBlob,
153 },
154 #[serde(rename = "functionCall")]
155 FunctionCall {
156 #[serde(rename = "functionCall")]
157 function_call: GeminiFunctionCall,
158 },
159 #[serde(rename = "fileData")]
160 FileData {
161 #[serde(rename = "fileData")]
162 file_data: GeminiFileData,
163 },
164 #[serde(rename = "executableCode")]
165 ExecutableCode {
166 #[serde(rename = "executableCode")]
167 executable_code: GeminiExecutableCode,
168 },
169 }
171
172#[derive(Debug, Deserialize)]
174struct GeminiResponsePart {
175 #[serde(default)] thought: bool,
177 #[serde(default, rename = "thoughtSignature", alias = "thought_signature")]
178 thought_signature: Option<String>,
179
180 #[serde(flatten)] data: GeminiResponsePartData,
182}
183
184#[derive(Debug, Serialize, Deserialize)]
185struct GeminiFunctionCall {
186 name: String,
187 args: Value,
188}
189
190#[derive(Debug, Serialize, PartialEq)]
191struct GeminiTool {
192 #[serde(rename = "functionDeclarations")]
193 function_declarations: Vec<GeminiFunctionDeclaration>,
194}
195
196#[derive(Debug, Serialize, PartialEq)]
197struct GeminiFunctionDeclaration {
198 name: String,
199 description: String,
200 parameters: GeminiParameterSchema,
201}
202
203#[derive(Debug, Serialize, PartialEq)]
204struct GeminiParameterSchema {
205 #[serde(rename = "type")]
206 schema_type: String, properties: serde_json::Map<String, Value>,
208 required: Vec<String>,
209}
210
211#[derive(Debug, Deserialize)]
212struct GeminiResponse {
213 #[serde(rename = "candidates")]
214 #[serde(skip_serializing_if = "Option::is_none")]
215 candidates: Option<Vec<GeminiCandidate>>,
216 #[serde(rename = "promptFeedback")]
217 #[serde(skip_serializing_if = "Option::is_none")]
218 prompt_feedback: Option<GeminiPromptFeedback>,
219 #[serde(rename = "usageMetadata")]
220 #[serde(skip_serializing_if = "Option::is_none")]
221 usage_metadata: Option<GeminiUsageMetadata>,
222}
223
224#[derive(Debug, Deserialize)]
225struct GeminiCandidate {
226 content: GeminiContentResponse,
227 #[serde(rename = "finishReason")]
228 #[serde(skip_serializing_if = "Option::is_none")]
229 finish_reason: Option<GeminiFinishReason>,
230 #[serde(rename = "safetyRatings")]
231 #[serde(skip_serializing_if = "Option::is_none")]
232 safety_ratings: Option<Vec<GeminiSafetyRating>>,
233 #[serde(rename = "citationMetadata")]
234 #[serde(skip_serializing_if = "Option::is_none")]
235 citation_metadata: Option<GeminiCitationMetadata>,
236}
237
238#[derive(Debug, Deserialize)]
239#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
240enum GeminiFinishReason {
241 FinishReasonUnspecified,
242 Stop,
243 MaxTokens,
244 Safety,
245 Recitation,
246 Other,
247 #[serde(rename = "TOOL_CODE_ERROR")]
248 ToolCodeError,
249 #[serde(rename = "TOOL_EXECUTION_HALT")]
250 ToolExecutionHalt,
251 MalformedFunctionCall,
252}
253
254#[derive(Debug, Deserialize)]
255struct GeminiPromptFeedback {
256 #[serde(rename = "blockReason")]
257 #[serde(skip_serializing_if = "Option::is_none")]
258 block_reason: Option<GeminiBlockReason>,
259 #[serde(rename = "safetyRatings")]
260 #[serde(skip_serializing_if = "Option::is_none")]
261 safety_ratings: Option<Vec<GeminiSafetyRating>>,
262}
263
264#[derive(Debug, Deserialize)]
265#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
266enum GeminiBlockReason {
267 BlockReasonUnspecified,
268 Safety,
269 Other,
270}
271
272#[derive(Debug, Deserialize)]
273#[expect(dead_code)]
274struct GeminiSafetyRating {
275 category: GeminiHarmCategory,
276 probability: GeminiHarmProbability,
277 #[serde(default)] blocked: bool,
279}
280
281#[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")]
283#[expect(clippy::enum_variant_names)]
284enum GeminiHarmCategory {
285 HarmCategoryUnspecified,
286 HarmCategoryDerogatory,
287 HarmCategoryToxicity,
288 HarmCategoryViolence,
289 HarmCategorySexual,
290 HarmCategoryMedical,
291 HarmCategoryDangerous,
292 HarmCategoryHarassment,
293 HarmCategoryHateSpeech,
294 HarmCategorySexuallyExplicit,
295 HarmCategoryDangerousContent,
296 HarmCategoryCivicIntegrity,
297}
298
299#[derive(Debug, Deserialize)]
300#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
301enum GeminiHarmProbability {
302 HarmProbabilityUnspecified,
303 Negligible,
304 Low,
305 Medium,
306 High,
307}
308
309#[expect(dead_code)]
310#[derive(Debug, Deserialize)]
311struct GeminiCitationMetadata {
312 #[serde(rename = "citationSources")]
313 #[serde(skip_serializing_if = "Option::is_none")]
314 citation_sources: Option<Vec<GeminiCitationSource>>,
315}
316
317#[expect(dead_code)]
318#[derive(Debug, Deserialize)]
319struct GeminiCitationSource {
320 #[serde(rename = "startIndex")]
321 #[serde(skip_serializing_if = "Option::is_none")]
322 start_index: Option<i32>,
323 #[serde(rename = "endIndex")]
324 #[serde(skip_serializing_if = "Option::is_none")]
325 end_index: Option<i32>,
326 #[serde(skip_serializing_if = "Option::is_none")]
327 uri: Option<String>,
328 #[serde(skip_serializing_if = "Option::is_none")]
329 license: Option<String>,
330}
331
332#[derive(Debug, Deserialize)]
333struct GeminiUsageMetadata {
334 #[serde(rename = "promptTokenCount")]
335 #[serde(skip_serializing_if = "Option::is_none")]
336 prompt: Option<i32>,
337 #[serde(rename = "candidatesTokenCount")]
338 #[serde(skip_serializing_if = "Option::is_none")]
339 candidates: Option<i32>,
340 #[serde(rename = "totalTokenCount")]
341 #[serde(skip_serializing_if = "Option::is_none")]
342 total: Option<i32>,
343}
344
345#[derive(Debug, Serialize, Deserialize)]
346struct GeminiFunctionResponse {
347 name: String,
348 response: GeminiResponseContent,
349}
350
351#[derive(Debug, Serialize, Deserialize)]
352struct GeminiResponseContent {
353 content: Value,
354}
355
356#[derive(Debug, Serialize, Deserialize)]
357struct GeminiExecutableCode {
358 language: String, code: String,
360}
361
362#[derive(Debug, Deserialize)]
363#[expect(dead_code)]
364struct GeminiContentResponse {
365 role: String,
366 parts: Vec<GeminiResponsePart>,
367}
368
369fn convert_messages(messages: Vec<AppMessage>) -> Vec<GeminiContent> {
370 messages
371 .into_iter()
372 .filter_map(|msg| match &msg.data {
373 crate::app::conversation::MessageData::User { content, .. } => {
374 let parts: Vec<GeminiRequestPart> = content
375 .iter()
376 .map(|user_content| match user_content {
377 UserContent::Text { text } => {
378 GeminiRequestPart::Text { text: text.clone() }
379 }
380 UserContent::CommandExecution {
381 command,
382 stdout,
383 stderr,
384 exit_code,
385 } => GeminiRequestPart::Text {
386 text: UserContent::format_command_execution_as_xml(
387 command, stdout, stderr, *exit_code,
388 ),
389 },
390 })
391 .collect();
392
393 if parts.is_empty() {
395 None
396 } else {
397 Some(GeminiContent {
398 role: "user".to_string(),
399 parts,
400 })
401 }
402 }
403 crate::app::conversation::MessageData::Assistant { content, .. } => {
404 let parts: Vec<GeminiRequestPart> = content
405 .iter()
406 .filter_map(|assistant_content| match assistant_content {
407 AssistantContent::Text { text } => {
408 Some(GeminiRequestPart::Text { text: text.clone() })
409 }
410 AssistantContent::ToolCall {
411 tool_call,
412 thought_signature,
413 } => Some(GeminiRequestPart::FunctionCall {
414 function_call: GeminiFunctionCall {
415 name: tool_call.name.clone(),
416 args: tool_call.parameters.clone(),
417 },
418 thought_signature: thought_signature
419 .as_ref()
420 .map(|signature| signature.as_str().to_string()),
421 }),
422 AssistantContent::Thought { .. } => {
423 None
425 }
426 })
427 .collect();
428
429 Some(GeminiContent {
431 role: "model".to_string(),
432 parts,
433 })
434 }
435 crate::app::conversation::MessageData::Tool {
436 tool_use_id,
437 result,
438 ..
439 } => {
440 let result_value = match result {
442 ToolResult::Error(e) => Value::String(format!("Error: {e}")),
443 _ => {
444 serde_json::to_value(result)
446 .unwrap_or_else(|_| Value::String(result.llm_format()))
447 }
448 };
449
450 let parts = vec![GeminiRequestPart::FunctionResponse {
451 function_response: GeminiFunctionResponse {
452 name: tool_use_id.clone(), response: GeminiResponseContent {
454 content: result_value,
455 },
456 },
457 }];
458
459 Some(GeminiContent {
460 role: "function".to_string(),
461 parts,
462 })
463 }
464 })
465 .collect()
466}
467
468fn resolve_ref<'a>(root: &'a Value, schema: &'a Value) -> Option<&'a Value> {
469 let reference = schema.get("$ref").and_then(|v| v.as_str())?;
470 let path = reference.strip_prefix("#/")?;
471 let mut current = root;
472 for segment in path.split('/') {
473 current = current.get(segment)?;
474 }
475 Some(current)
476}
477
478fn infer_type_from_enum(values: &[Value]) -> Option<String> {
479 let mut has_string = false;
480 let mut has_number = false;
481 let mut has_bool = false;
482 let mut has_object = false;
483 let mut has_array = false;
484
485 for value in values {
486 match value {
487 Value::String(_) => has_string = true,
488 Value::Number(_) => has_number = true,
489 Value::Bool(_) => has_bool = true,
490 Value::Object(_) => has_object = true,
491 Value::Array(_) => has_array = true,
492 Value::Null => {}
493 }
494 }
495
496 let kind_count = u8::from(has_string)
497 + u8::from(has_number)
498 + u8::from(has_bool)
499 + u8::from(has_object)
500 + u8::from(has_array);
501
502 if kind_count != 1 {
503 return None;
504 }
505
506 if has_string {
507 Some("string".to_string())
508 } else if has_number {
509 Some("number".to_string())
510 } else if has_bool {
511 Some("boolean".to_string())
512 } else if has_object {
513 Some("object".to_string())
514 } else if has_array {
515 Some("array".to_string())
516 } else {
517 None
518 }
519}
520
521fn normalize_type(value: &Value) -> Value {
522 if let Some(type_str) = value.as_str() {
523 return Value::String(type_str.to_string());
524 }
525
526 if let Some(type_array) = value.as_array()
527 && let Some(primary_type) = type_array
528 .iter()
529 .find_map(|v| if v.is_null() { None } else { v.as_str() })
530 {
531 return Value::String(primary_type.to_string());
532 }
533
534 Value::String("string".to_string())
535}
536
537fn extract_enum_values(value: &Value) -> Vec<Value> {
538 let Some(obj) = value.as_object() else {
539 return Vec::new();
540 };
541
542 if let Some(enum_values) = obj.get("enum").and_then(|v| v.as_array()) {
543 return enum_values
544 .iter()
545 .filter(|v| !v.is_null())
546 .cloned()
547 .collect();
548 }
549
550 if let Some(const_value) = obj.get("const") {
551 if const_value.is_null() {
552 return Vec::new();
553 }
554 return vec![const_value.clone()];
555 }
556
557 Vec::new()
558}
559
560fn merge_property(properties: &mut serde_json::Map<String, Value>, key: &str, value: &Value) {
561 match properties.get_mut(key) {
562 None => {
563 properties.insert(key.to_string(), value.clone());
564 }
565 Some(existing) => {
566 if existing == value {
567 return;
568 }
569
570 let existing_values = extract_enum_values(existing);
571 let incoming_values = extract_enum_values(value);
572 if incoming_values.is_empty() && existing_values.is_empty() {
573 return;
574 }
575
576 let mut combined = existing_values;
577 for item in incoming_values {
578 if !combined.contains(&item) {
579 combined.push(item);
580 }
581 }
582
583 if combined.is_empty() {
584 return;
585 }
586
587 if let Some(obj) = existing.as_object_mut() {
588 obj.remove("const");
589 obj.insert("enum".to_string(), Value::Array(combined.clone()));
590 if !obj.contains_key("type")
591 && let Some(inferred) = infer_type_from_enum(&combined)
592 {
593 obj.insert("type".to_string(), Value::String(inferred));
594 }
595 }
596 }
597 }
598}
599
600fn merge_union_schemas(root: &Value, variants: &[Value]) -> Value {
601 let mut merged_props = serde_json::Map::new();
602 let mut required_intersection: Option<std::collections::BTreeSet<String>> = None;
603 let mut enum_values: Vec<Value> = Vec::new();
604 let mut type_candidates: Vec<String> = Vec::new();
605
606 for variant in variants {
607 let sanitized = sanitize_for_gemini(root, variant);
608
609 if let Some(schema_type) = sanitized.get("type").and_then(|v| v.as_str()) {
610 type_candidates.push(schema_type.to_string());
611 }
612
613 if let Some(props) = sanitized.get("properties").and_then(|v| v.as_object()) {
614 for (key, value) in props {
615 merge_property(&mut merged_props, key, value);
616 }
617 }
618
619 if let Some(req) = sanitized.get("required").and_then(|v| v.as_array()) {
620 let req_set: std::collections::BTreeSet<String> = req
621 .iter()
622 .filter_map(|item| item.as_str().map(|s| s.to_string()))
623 .collect();
624
625 required_intersection = match required_intersection.take() {
626 None => Some(req_set),
627 Some(existing) => Some(
628 existing
629 .intersection(&req_set)
630 .cloned()
631 .collect::<std::collections::BTreeSet<String>>(),
632 ),
633 };
634 }
635
636 if let Some(values) = sanitized.get("enum").and_then(|v| v.as_array()) {
637 for value in values {
638 if value.is_null() {
639 continue;
640 }
641 if !enum_values.contains(value) {
642 enum_values.push(value.clone());
643 }
644 }
645 }
646 }
647
648 let schema_type = if !merged_props.is_empty() {
649 "object".to_string()
650 } else if let Some(inferred) = infer_type_from_enum(&enum_values) {
651 inferred
652 } else if let Some(first) = type_candidates.first() {
653 first.clone()
654 } else {
655 "string".to_string()
656 };
657
658 let mut merged = serde_json::Map::new();
659 merged.insert("type".to_string(), Value::String(schema_type));
660
661 if !merged_props.is_empty() {
662 merged.insert("properties".to_string(), Value::Object(merged_props));
663 }
664
665 if let Some(required_set) = required_intersection
666 && !required_set.is_empty()
667 {
668 merged.insert(
669 "required".to_string(),
670 Value::Array(
671 required_set
672 .into_iter()
673 .map(Value::String)
674 .collect::<Vec<_>>(),
675 ),
676 );
677 }
678
679 if !enum_values.is_empty() {
680 merged.insert("enum".to_string(), Value::Array(enum_values));
681 }
682
683 Value::Object(merged)
684}
685
686fn sanitize_for_gemini(root: &Value, schema: &Value) -> Value {
687 if let Some(resolved) = resolve_ref(root, schema) {
688 return sanitize_for_gemini(root, resolved);
689 }
690
691 let Some(obj) = schema.as_object() else {
692 return schema.clone();
693 };
694
695 if let Some(union) = obj
696 .get("oneOf")
697 .or_else(|| obj.get("anyOf"))
698 .or_else(|| obj.get("allOf"))
699 .and_then(|v| v.as_array())
700 {
701 return merge_union_schemas(root, union);
702 }
703
704 let mut out = serde_json::Map::new();
705 for (key, value) in obj {
706 match key.as_str() {
707 "$ref"
708 | "$defs"
709 | "oneOf"
710 | "anyOf"
711 | "allOf"
712 | "const"
713 | "additionalProperties"
714 | "default"
715 | "examples"
716 | "title"
717 | "pattern"
718 | "minLength"
719 | "maxLength"
720 | "minimum"
721 | "maximum"
722 | "minItems"
723 | "maxItems"
724 | "uniqueItems"
725 | "deprecated" => {}
726 "type" => {
727 out.insert("type".to_string(), normalize_type(value));
728 }
729 "properties" => {
730 if let Some(props) = value.as_object() {
731 let mut sanitized_props = serde_json::Map::new();
732 for (prop_key, prop_value) in props {
733 sanitized_props
734 .insert(prop_key.clone(), sanitize_for_gemini(root, prop_value));
735 }
736 out.insert("properties".to_string(), Value::Object(sanitized_props));
737 }
738 }
739 "items" => {
740 out.insert("items".to_string(), sanitize_for_gemini(root, value));
741 }
742 "enum" => {
743 let values = value
744 .as_array()
745 .map(|items| {
746 items
747 .iter()
748 .filter(|v| !v.is_null())
749 .cloned()
750 .collect::<Vec<_>>()
751 })
752 .unwrap_or_default();
753 out.insert("enum".to_string(), Value::Array(values));
754 }
755 _ => {
756 out.insert(key.clone(), sanitize_for_gemini(root, value));
757 }
758 }
759 }
760
761 if let Some(const_value) = obj.get("const")
762 && !const_value.is_null()
763 {
764 out.insert("enum".to_string(), Value::Array(vec![const_value.clone()]));
765 if !out.contains_key("type")
766 && let Some(inferred) = infer_type_from_enum(std::slice::from_ref(const_value))
767 {
768 out.insert("type".to_string(), Value::String(inferred));
769 }
770 }
771
772 if out.get("type") == Some(&Value::String("object".to_string()))
773 && !out.contains_key("properties")
774 {
775 out.insert(
776 "properties".to_string(),
777 Value::Object(serde_json::Map::new()),
778 );
779 }
780
781 if !out.contains_key("type") {
782 if out.contains_key("properties") {
783 out.insert("type".to_string(), Value::String("object".to_string()));
784 } else if let Some(enum_values) = out.get("enum").and_then(|v| v.as_array())
785 && let Some(inferred) = infer_type_from_enum(enum_values)
786 {
787 out.insert("type".to_string(), Value::String(inferred));
788 }
789 }
790
791 Value::Object(out)
792}
793
794fn simplify_property_schema(key: &str, tool_name: &str, property_value: &Value) -> Value {
795 if let Some(prop_map_orig) = property_value.as_object() {
796 let mut simplified_prop = prop_map_orig.clone();
797
798 if simplified_prop.remove("additionalProperties").is_some() {
800 debug!(target: "gemini::simplify_property_schema", "Removed 'additionalProperties' from property '{}' in tool '{}'", key, tool_name);
801 }
802
803 if let Some(type_val) = simplified_prop.get_mut("type") {
805 if let Some(type_array) = type_val.as_array() {
806 if let Some(primary_type) = type_array
807 .iter()
808 .find_map(|v| if v.is_null() { None } else { v.as_str() })
809 {
810 *type_val = serde_json::Value::String(primary_type.to_string());
811 } else {
812 warn!(target: "gemini::simplify_property_schema", "Could not determine primary type for property '{}' in tool '{}', defaulting to string.", key, tool_name);
813 *type_val = serde_json::Value::String("string".to_string());
814 }
815 } else if !type_val.is_string() {
816 warn!(target: "gemini::simplify_property_schema", "Unexpected 'type' format for property '{}' in tool '{}': {:?}. Defaulting to string.", key, tool_name, type_val);
817 *type_val = serde_json::Value::String("string".to_string());
818 }
819 }
821
822 if simplified_prop.get("type") == Some(&serde_json::Value::String("integer".to_string()))
824 && let Some(format_val) = simplified_prop.get_mut("format")
825 && format_val.as_str() == Some("uint64")
826 {
827 *format_val = serde_json::Value::String("int64".to_string());
828 }
831
832 if simplified_prop.get("type") == Some(&serde_json::Value::String("string".to_string())) {
834 let should_remove_format = simplified_prop
835 .get("format")
836 .and_then(|f| f.as_str())
837 .is_some_and(|format_str| format_str != "enum" && format_str != "date-time");
838
839 if should_remove_format
840 && let Some(format_val) = simplified_prop.remove("format")
841 && let Some(format_str) = format_val.as_str()
842 {
843 debug!(target: "gemini::simplify_property_schema", "Removed unsupported format '{}' from string property '{}' in tool '{}'", format_str, key, tool_name);
844 }
845
846 if simplified_prop.remove("minLength").is_some() {
848 debug!(target: "gemini::simplify_property_schema", "Removed 'minLength' from string property '{}' in tool '{}'", key, tool_name);
849 }
850 if simplified_prop.remove("maxLength").is_some() {
851 debug!(target: "gemini::simplify_property_schema", "Removed 'maxLength' from string property '{}' in tool '{}'", key, tool_name);
852 }
853 if simplified_prop.remove("pattern").is_some() {
854 debug!(target: "gemini::simplify_property_schema", "Removed 'pattern' from string property '{}' in tool '{}'", key, tool_name);
855 }
856 }
857
858 if simplified_prop.get("type") == Some(&serde_json::Value::String("array".to_string()))
860 && let Some(items_val) = simplified_prop.get_mut("items")
861 {
862 *items_val = simplify_property_schema(&format!("{key}.items"), tool_name, items_val);
863 }
864
865 if simplified_prop.get("type") == Some(&serde_json::Value::String("object".to_string()))
867 && let Some(Value::Object(props)) = simplified_prop.get_mut("properties")
868 {
869 let simplified_nested_props: serde_json::Map<String, Value> = props
870 .iter()
871 .map(|(nested_key, nested_value)| {
872 (
873 nested_key.clone(),
874 simplify_property_schema(
875 &format!("{key}.{nested_key}"),
876 tool_name,
877 nested_value,
878 ),
879 )
880 })
881 .collect();
882 *props = simplified_nested_props;
883 }
884
885 serde_json::Value::Object(simplified_prop)
886 } else {
887 warn!(target: "gemini::simplify_property_schema", "Property value for '{}' in tool '{}' is not an object: {:?}. Using original value.", key, tool_name, property_value);
888 property_value.clone() }
890}
891
892fn convert_tools(tools: Vec<ToolSchema>) -> Vec<GeminiTool> {
893 let function_declarations = tools
894 .into_iter()
895 .map(|tool| {
896 let root_schema = tool.input_schema.as_value();
897 let summary = tool.input_schema.summary();
898 let schema_type = if summary.schema_type.is_empty() {
899 "object".to_string()
900 } else {
901 summary.schema_type
902 };
903
904 let simplified_properties = summary
906 .properties
907 .iter()
908 .map(|(key, value)| {
909 let sanitized = sanitize_for_gemini(root_schema, value);
910 (
911 key.clone(),
912 simplify_property_schema(key, &tool.name, &sanitized),
913 )
914 })
915 .collect();
916
917 let parameters = GeminiParameterSchema {
919 schema_type, properties: simplified_properties, required: summary.required, };
923
924 GeminiFunctionDeclaration {
925 name: tool.name,
926 description: tool.description,
927 parameters,
928 }
929 })
930 .collect();
931
932 vec![GeminiTool {
933 function_declarations,
934 }]
935}
936
937fn convert_response(response: GeminiResponse) -> Result<CompletionResponse, ApiError> {
938 if let Some(feedback) = &response.prompt_feedback
940 && let Some(reason) = &feedback.block_reason
941 {
942 let details = format!(
943 "Prompt blocked due to {:?}. Safety ratings: {:?}",
944 reason, feedback.safety_ratings
945 );
946 warn!(target: "gemini::convert_response", "{}", details);
947 return Err(ApiError::RequestBlocked {
949 provider: "google".to_string(), details,
951 });
952 }
953
954 let candidates = if let Some(cands) = response.candidates {
956 if cands.is_empty() {
957 warn!(target: "gemini::convert_response", "No candidates received, and prompt was not blocked.");
960 return Err(ApiError::NoChoices {
962 provider: "google".to_string(),
963 });
964 }
965 cands } else {
967 warn!(target: "gemini::convert_response", "No candidates field in Gemini response.");
968 return Err(ApiError::NoChoices {
970 provider: "google".to_string(),
971 });
972 };
973
974 let candidate = &candidates[0];
977
978 if let Some(reason) = &candidate.finish_reason {
980 match reason {
981 GeminiFinishReason::Stop => { }
982 GeminiFinishReason::MaxTokens => {
983 warn!(target: "gemini::convert_response", "Response stopped due to MaxTokens limit.");
984 }
985 GeminiFinishReason::Safety => {
986 warn!(target: "gemini::convert_response", "Response stopped due to safety settings. Ratings: {:?}", candidate.safety_ratings);
987 }
989 GeminiFinishReason::Recitation => {
990 warn!(target: "gemini::convert_response", "Response stopped due to potential recitation. Citations: {:?}", candidate.citation_metadata);
991 }
992 GeminiFinishReason::MalformedFunctionCall => {
993 warn!(target: "gemini::convert_response", "Response stopped due to malformed function call.");
994 }
995 _ => {
996 info!(target: "gemini::convert_response", "Response finished with reason: {:?}", reason);
997 }
998 }
999 }
1000
1001 if let Some(usage) = &response.usage_metadata {
1003 debug!(target: "gemini::convert_response", "Usage - Prompt Tokens: {:?}, Candidates Tokens: {:?}, Total Tokens: {:?}",
1004 usage.prompt, usage.candidates, usage.total);
1005 }
1006
1007 let content: Vec<AssistantContent> = candidate
1008 .content .parts .iter()
1011 .filter_map(|part| { if part.thought {
1014 debug!(target: "gemini::convert_response", "Received thought part: {:?}", part);
1015 if let GeminiResponsePartData::Text { text } = &part.data {
1017 Some(AssistantContent::Thought {
1018 thought: ThoughtContent::Simple {
1019 text: text.clone(),
1020 },
1021 })
1022 } else {
1023 warn!(target: "gemini::convert_response", "Thought part contains non-text data: {:?}", part.data);
1024 None
1025 }
1026 } else {
1027 match &part.data {
1029 GeminiResponsePartData::Text { text } => Some(AssistantContent::Text {
1030 text: text.clone(),
1031 }),
1032 GeminiResponsePartData::InlineData { inline_data } => {
1033 warn!(target: "gemini::convert_response", "Received InlineData part (MIME type: {}). Converting to placeholder text.", inline_data.mime_type);
1034 Some(AssistantContent::Text { text: format!("[Inline Data: {}]", inline_data.mime_type) })
1035 }
1036 GeminiResponsePartData::FunctionCall { function_call } => {
1037 Some(AssistantContent::ToolCall {
1038 tool_call: steer_tools::ToolCall {
1039 id: uuid::Uuid::new_v4().to_string(), name: function_call.name.clone(),
1041 parameters: function_call.args.clone(),
1042 },
1043 thought_signature: part
1044 .thought_signature
1045 .clone()
1046 .map(ThoughtSignature::new),
1047 })
1048 }
1049 GeminiResponsePartData::FileData { file_data } => {
1050 warn!(target: "gemini::convert_response", "Received FileData part (URI: {}). Converting to placeholder text.", file_data.file_uri);
1051 Some(AssistantContent::Text { text: format!("[File Data: {}]", file_data.file_uri) })
1052 }
1053 GeminiResponsePartData::ExecutableCode { executable_code } => {
1054 info!(target: "gemini::convert_response", "Received ExecutableCode part ({}). Converting to text.",
1055 executable_code.language);
1056 Some(AssistantContent::Text {
1057 text: format!(
1058 "```{}
1059{}
1060```",
1061 executable_code.language.to_lowercase(),
1062 executable_code.code
1063 ),
1064 })
1065 }
1066 }
1067 }
1068 })
1069 .collect();
1070
1071 Ok(CompletionResponse { content })
1072}
1073
1074#[async_trait]
1075impl Provider for GeminiClient {
1076 fn name(&self) -> &'static str {
1077 "google"
1078 }
1079
1080 async fn complete(
1081 &self,
1082 model_id: &ModelId,
1083 messages: Vec<AppMessage>,
1084 system: Option<SystemContext>,
1085 tools: Option<Vec<ToolSchema>>,
1086 _call_options: Option<ModelParameters>,
1087 token: CancellationToken,
1088 ) -> Result<CompletionResponse, ApiError> {
1089 let model_name = &model_id.id; let url = format!(
1091 "{}/models/{}:generateContent?key={}",
1092 GEMINI_API_BASE, model_name, self.api_key
1093 );
1094
1095 let gemini_contents = convert_messages(messages);
1096
1097 let system_instruction = system
1098 .and_then(|context| context.render())
1099 .map(|instructions| GeminiSystemInstruction {
1100 parts: vec![GeminiRequestPart::Text { text: instructions }],
1101 });
1102
1103 let gemini_tools = tools.map(convert_tools);
1104
1105 let (temperature, top_p, max_output_tokens) = {
1107 let opts = _call_options.as_ref();
1108 (
1109 opts.and_then(|o| o.temperature).or(Some(1.0)),
1110 opts.and_then(|o| o.top_p).or(Some(0.95)),
1111 opts.and_then(|o| o.max_tokens)
1112 .map(|v| v as i32)
1113 .or(Some(65536)),
1114 )
1115 };
1116 let thinking_config = _call_options
1117 .as_ref()
1118 .and_then(|o| o.thinking_config)
1119 .and_then(|tc| {
1120 if tc.enabled {
1121 Some(GeminiThinkingConfig {
1122 include_thoughts: tc.include_thoughts,
1123 thinking_budget: tc.budget_tokens.map(|v| v as i32),
1124 })
1125 } else {
1126 None
1127 }
1128 });
1129
1130 let request = GeminiRequest {
1131 contents: gemini_contents,
1132 system_instruction,
1133 tools: gemini_tools,
1134 generation_config: Some(GeminiGenerationConfig {
1135 max_output_tokens,
1136 temperature,
1137 top_p,
1138 thinking_config,
1139 ..Default::default()
1140 }),
1141 };
1142
1143 let response = tokio::select! {
1144 biased;
1145 () = token.cancelled() => {
1146 debug!(target: "gemini::complete", "Cancellation token triggered before sending request.");
1147 return Err(ApiError::Cancelled{ provider: self.name().to_string()});
1148 }
1149 res = self.client.post(&url).json(&request).send() => {
1150 res.map_err(ApiError::Network)?
1151 }
1152 };
1153 let status = response.status();
1154
1155 if status != StatusCode::OK {
1156 let error_text = response.text().await.map_err(ApiError::Network)?;
1157 error!(target: "Gemini API Error Response", "Status: {}, Body: {}", status, error_text);
1158 return Err(match status.as_u16() {
1159 401 | 403 => ApiError::AuthenticationFailed {
1160 provider: self.name().to_string(),
1161 details: error_text,
1162 },
1163 429 => ApiError::RateLimited {
1164 provider: self.name().to_string(),
1165 details: error_text,
1166 },
1167 400 | 404 => {
1168 error!(target: "Gemini API Error Response", "Status: {}, Body: {}, Request: {}", status, error_text, serde_json::to_string_pretty(&request).unwrap_or_else(|_| "Failed to serialize request".to_string()));
1169 ApiError::InvalidRequest {
1170 provider: self.name().to_string(),
1171 details: error_text,
1172 }
1173 } 500..=599 => ApiError::ServerError {
1175 provider: self.name().to_string(),
1176 status_code: status.as_u16(),
1177 details: error_text,
1178 },
1179 _ => ApiError::Unknown {
1180 provider: self.name().to_string(),
1181 details: error_text,
1182 },
1183 });
1184 }
1185
1186 let response_text = response.text().await.map_err(ApiError::Network)?;
1187
1188 match serde_json::from_str::<GeminiResponse>(&response_text) {
1189 Ok(gemini_response) => {
1190 convert_response(gemini_response).map_err(|e| ApiError::ResponseParsingError {
1191 provider: self.name().to_string(),
1192 details: e.to_string(),
1193 })
1194 }
1195 Err(e) => {
1196 error!(target: "Gemini API JSON Parsing Error", "Failed to parse JSON: {}. Response body:\n{}", e, response_text);
1197 Err(ApiError::ResponseParsingError {
1198 provider: self.name().to_string(),
1199 details: format!("Status: {status}, Error: {e}, Body: {response_text}"),
1200 })
1201 }
1202 }
1203 }
1204
1205 async fn stream_complete(
1206 &self,
1207 model_id: &ModelId,
1208 messages: Vec<AppMessage>,
1209 system: Option<SystemContext>,
1210 tools: Option<Vec<ToolSchema>>,
1211 _call_options: Option<ModelParameters>,
1212 token: CancellationToken,
1213 ) -> Result<CompletionStream, ApiError> {
1214 let model_name = &model_id.id;
1215 let url = format!(
1216 "{}/models/{}:streamGenerateContent?alt=sse&key={}",
1217 GEMINI_API_BASE, model_name, self.api_key
1218 );
1219
1220 let gemini_contents = convert_messages(messages);
1221
1222 let system_instruction = system
1223 .and_then(|context| context.render())
1224 .map(|instructions| GeminiSystemInstruction {
1225 parts: vec![GeminiRequestPart::Text { text: instructions }],
1226 });
1227
1228 let gemini_tools = tools.map(convert_tools);
1229
1230 let (temperature, top_p, max_output_tokens) = {
1231 let opts = _call_options.as_ref();
1232 (
1233 opts.and_then(|o| o.temperature).or(Some(1.0)),
1234 opts.and_then(|o| o.top_p).or(Some(0.95)),
1235 opts.and_then(|o| o.max_tokens)
1236 .map(|v| v as i32)
1237 .or(Some(65536)),
1238 )
1239 };
1240 let thinking_config = _call_options
1241 .as_ref()
1242 .and_then(|o| o.thinking_config)
1243 .and_then(|tc| {
1244 if tc.enabled {
1245 Some(GeminiThinkingConfig {
1246 include_thoughts: tc.include_thoughts,
1247 thinking_budget: tc.budget_tokens.map(|v| v as i32),
1248 })
1249 } else {
1250 None
1251 }
1252 });
1253
1254 let request = GeminiRequest {
1255 contents: gemini_contents,
1256 system_instruction,
1257 tools: gemini_tools,
1258 generation_config: Some(GeminiGenerationConfig {
1259 max_output_tokens,
1260 temperature,
1261 top_p,
1262 thinking_config,
1263 ..Default::default()
1264 }),
1265 };
1266
1267 let response = tokio::select! {
1268 biased;
1269 () = token.cancelled() => {
1270 return Err(ApiError::Cancelled{ provider: self.name().to_string()});
1271 }
1272 res = self.client.post(&url).json(&request).send() => {
1273 res.map_err(ApiError::Network)?
1274 }
1275 };
1276
1277 let status = response.status();
1278 if status != StatusCode::OK {
1279 let error_text = response.text().await.map_err(ApiError::Network)?;
1280 error!(target: "gemini::stream", "API error - Status: {}, Body: {}", status, error_text);
1281 return Err(match status.as_u16() {
1282 401 | 403 => ApiError::AuthenticationFailed {
1283 provider: self.name().to_string(),
1284 details: error_text,
1285 },
1286 429 => ApiError::RateLimited {
1287 provider: self.name().to_string(),
1288 details: error_text,
1289 },
1290 400 | 404 => ApiError::InvalidRequest {
1291 provider: self.name().to_string(),
1292 details: error_text,
1293 },
1294 500..=599 => ApiError::ServerError {
1295 provider: self.name().to_string(),
1296 status_code: status.as_u16(),
1297 details: error_text,
1298 },
1299 _ => ApiError::Unknown {
1300 provider: self.name().to_string(),
1301 details: error_text,
1302 },
1303 });
1304 }
1305
1306 let byte_stream = response.bytes_stream();
1307 let sse_stream = parse_sse_stream(byte_stream);
1308
1309 Ok(Box::pin(Self::convert_gemini_stream(sse_stream, token)))
1310 }
1311}
1312
1313impl GeminiClient {
1314 fn convert_gemini_stream(
1315 mut sse_stream: impl futures::Stream<Item = Result<crate::api::sse::SseEvent, SseParseError>>
1316 + Unpin
1317 + Send
1318 + 'static,
1319 token: CancellationToken,
1320 ) -> impl futures::Stream<Item = StreamChunk> + Send + 'static {
1321 async_stream::stream! {
1322 let mut content: Vec<AssistantContent> = Vec::new();
1323 loop {
1324 if token.is_cancelled() {
1325 yield StreamChunk::Error(StreamError::Cancelled);
1326 break;
1327 }
1328
1329 let event_result = tokio::select! {
1330 biased;
1331 () = token.cancelled() => {
1332 yield StreamChunk::Error(StreamError::Cancelled);
1333 break;
1334 }
1335 event = sse_stream.next() => event
1336 };
1337
1338 let Some(event_result) = event_result else {
1339 let content = std::mem::take(&mut content);
1340 yield StreamChunk::MessageComplete(CompletionResponse { content });
1341 break;
1342 };
1343
1344 let event = match event_result {
1345 Ok(e) => e,
1346 Err(e) => {
1347 yield StreamChunk::Error(StreamError::SseParse(e));
1348 break;
1349 }
1350 };
1351
1352 let chunk: GeminiResponse = match serde_json::from_str(&event.data) {
1353 Ok(c) => c,
1354 Err(e) => {
1355 debug!(target: "gemini::stream", "Failed to parse chunk: {} data: {}", e, event.data);
1356 continue;
1357 }
1358 };
1359
1360 if let Some(candidates) = chunk.candidates {
1361 for candidate in candidates {
1362 for part in candidate.content.parts {
1363 let GeminiResponsePart {
1364 thought,
1365 thought_signature,
1366 data,
1367 } = part;
1368 let thought_signature =
1369 thought_signature.map(ThoughtSignature::new);
1370
1371 if thought {
1372 if let GeminiResponsePartData::Text { text } = data {
1373 match content.last_mut() {
1374 Some(AssistantContent::Thought {
1375 thought: ThoughtContent::Simple { text: buf },
1376 }) => buf.push_str(&text),
1377 _ => {
1378 content.push(AssistantContent::Thought {
1379 thought: ThoughtContent::Simple { text: text.clone() },
1380 });
1381 }
1382 }
1383 yield StreamChunk::ThinkingDelta(text);
1384 }
1385 } else {
1386 match data {
1387 GeminiResponsePartData::Text { text } => {
1388 match content.last_mut() {
1389 Some(AssistantContent::Text { text: buf }) => buf.push_str(&text),
1390 _ => {
1391 content.push(AssistantContent::Text { text: text.clone() });
1392 }
1393 }
1394 yield StreamChunk::TextDelta(text);
1395 }
1396 GeminiResponsePartData::FunctionCall { function_call } => {
1397 let id = uuid::Uuid::new_v4().to_string();
1398 content.push(AssistantContent::ToolCall {
1399 tool_call: steer_tools::ToolCall {
1400 id: id.clone(),
1401 name: function_call.name.clone(),
1402 parameters: function_call.args.clone(),
1403 },
1404 thought_signature,
1405 });
1406 yield StreamChunk::ToolUseStart {
1407 id: id.clone(),
1408 name: function_call.name,
1409 };
1410 yield StreamChunk::ToolUseInputDelta {
1411 id,
1412 delta: function_call.args.to_string(),
1413 };
1414 }
1415 _ => {}
1416 }
1417 }
1418 }
1419 }
1420 }
1421 }
1422 }
1423 }
1424}
1425
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::sse::SseEvent;
    use futures::stream;
    use serde_json::json;

    /// Feeds canned SSE `data:` payloads through `GeminiClient::convert_gemini_stream`
    /// and collects every chunk it yields (the stream always terminates).
    async fn run_stream(payloads: &[&str], token: CancellationToken) -> Vec<StreamChunk> {
        let events: Vec<Result<SseEvent, SseParseError>> = payloads
            .iter()
            .map(|data| {
                Ok(SseEvent {
                    event_type: None,
                    data: (*data).to_string(),
                    id: None,
                })
            })
            .collect();
        GeminiClient::convert_gemini_stream(stream::iter(events), token)
            .collect()
            .await
    }

    #[test]
    fn test_simplify_property_schema_removes_additional_properties() {
        let property_value = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            },
            "additionalProperties": false
        });

        let expected = json!({
            "type": "object",
            "properties": {
                "name": {"type": "string"}
            }
        });

        let result = simplify_property_schema("testProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_removes_unsupported_string_formats() {
        let property_value = json!({
            "type": "string",
            "format": "uri",
            "minLength": 1,
            "maxLength": 100,
            "pattern": "^https://"
        });

        let expected = json!({
            "type": "string"
        });

        let result = simplify_property_schema("urlProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_keeps_supported_string_formats() {
        let property_value = json!({
            "type": "string",
            "format": "date-time"
        });

        let expected = json!({
            "type": "string",
            "format": "date-time"
        });

        let result = simplify_property_schema("dateProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_handles_array_types() {
        let property_value = json!({
            "type": ["string", "null"],
            "format": "email"
        });

        let expected = json!({
            "type": "string"
        });

        let result = simplify_property_schema("emailProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_recursively_handles_array_items() {
        let property_value = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "url": {
                        "type": "string",
                        "format": "uri"
                    }
                },
                "additionalProperties": false
            }
        });

        let expected = json!({
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "url": {
                        "type": "string"
                    }
                }
            }
        });

        let result = simplify_property_schema("linksProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_recursively_handles_nested_objects() {
        let property_value = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "type": "object",
                    "properties": {
                        "field": {
                            "type": "string",
                            "format": "hostname"
                        }
                    },
                    "additionalProperties": true
                }
            },
            "additionalProperties": false
        });

        let expected = json!({
            "type": "object",
            "properties": {
                "nested": {
                    "type": "object",
                    "properties": {
                        "field": {
                            "type": "string"
                        }
                    }
                }
            }
        });

        let result = simplify_property_schema("complexProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_simplify_property_schema_fixes_uint64_format() {
        let property_value = json!({
            "type": "integer",
            "format": "uint64"
        });

        let expected = json!({
            "type": "integer",
            "format": "int64"
        });

        let result = simplify_property_schema("idProp", "testTool", &property_value);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_convert_tools_integration() {
        use steer_tools::{InputSchema, ToolSchema};

        let tool = ToolSchema {
            name: "create_issue".to_string(),
            display_name: "Create Issue".to_string(),
            description: "Create an issue".to_string(),
            input_schema: InputSchema::object(
                {
                    let mut props = serde_json::Map::new();
                    props.insert(
                        "title".to_string(),
                        json!({
                            "type": "string",
                            "minLength": 1
                        }),
                    );
                    props.insert(
                        "links".to_string(),
                        json!({
                            "type": "array",
                            "items": {
                                "type": "object",
                                "properties": {
                                    "url": {
                                        "type": "string",
                                        "format": "uri"
                                    }
                                },
                                "additionalProperties": false
                            }
                        }),
                    );
                    props
                },
                vec!["title".to_string()],
            ),
        };

        let expected_tools = vec![GeminiTool {
            function_declarations: vec![GeminiFunctionDeclaration {
                name: "create_issue".to_string(),
                description: "Create an issue".to_string(),
                parameters: GeminiParameterSchema {
                    schema_type: "object".to_string(),
                    properties: {
                        let mut props = serde_json::Map::new();
                        props.insert(
                            "title".to_string(),
                            json!({
                                "type": "string"
                            }),
                        );
                        props.insert(
                            "links".to_string(),
                            json!({
                                "type": "array",
                                "items": {
                                    "type": "object",
                                    "properties": {
                                        "url": {
                                            "type": "string"
                                        }
                                    }
                                }
                            }),
                        );
                        props
                    },
                    required: vec!["title".to_string()],
                },
            }],
        }];

        let result = convert_tools(vec![tool]);
        assert_eq!(result, expected_tools);
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_text_deltas() {
        let chunks = run_stream(
            &[
                r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#,
                r#"{"candidates":[{"content":{"role":"model","parts":[{"text":" world"}]}}]}"#,
            ],
            CancellationToken::new(),
        )
        .await;

        assert_eq!(chunks.len(), 3);
        assert!(matches!(&chunks[0], StreamChunk::TextDelta(t) if t == "Hello"));
        assert!(matches!(&chunks[1], StreamChunk::TextDelta(t) if t == " world"));

        let StreamChunk::MessageComplete(response) = &chunks[2] else {
            panic!("Expected MessageComplete");
        };
        // Consecutive text fragments are merged into a single content item.
        assert_eq!(response.content.len(), 1);
        assert!(matches!(
            &response.content[0],
            AssistantContent::Text { text } if text == "Hello world"
        ));
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_with_thinking() {
        let chunks = run_stream(
            &[
                r#"{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"Let me think..."}]}}]}"#,
                r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"The answer"}]}}]}"#,
            ],
            CancellationToken::new(),
        )
        .await;

        assert_eq!(chunks.len(), 3);
        assert!(matches!(&chunks[0], StreamChunk::ThinkingDelta(t) if t == "Let me think..."));
        assert!(matches!(&chunks[1], StreamChunk::TextDelta(t) if t == "The answer"));

        let StreamChunk::MessageComplete(response) = &chunks[2] else {
            panic!("Expected MessageComplete");
        };
        assert_eq!(response.content.len(), 2);
        assert!(matches!(
            &response.content[0],
            AssistantContent::Thought { .. }
        ));
        assert!(matches!(
            &response.content[1],
            AssistantContent::Text { .. }
        ));
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_merges_consecutive_thoughts() {
        let chunks = run_stream(
            &[
                r#"{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"Step 1. "}]}}]}"#,
                r#"{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"Step 2."}]}}]}"#,
                r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Done"}]}}]}"#,
            ],
            CancellationToken::new(),
        )
        .await;

        assert_eq!(chunks.len(), 4);
        let StreamChunk::MessageComplete(response) = &chunks[3] else {
            panic!("Expected MessageComplete");
        };
        assert_eq!(response.content.len(), 2);
        assert!(matches!(
            &response.content[0],
            AssistantContent::Thought { thought: ThoughtContent::Simple { text } }
                if text == "Step 1. Step 2."
        ));
        assert!(matches!(
            &response.content[1],
            AssistantContent::Text { text } if text == "Done"
        ));
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_with_function_call() {
        let chunks = run_stream(
            &[
                r#"{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"get_weather","args":{"city":"NYC"}},"thoughtSignature":"sig_123"}]}}]}"#,
            ],
            CancellationToken::new(),
        )
        .await;

        assert_eq!(chunks.len(), 3);
        let (
            StreamChunk::ToolUseStart { id: start_id, name },
            StreamChunk::ToolUseInputDelta { id: delta_id, delta },
        ) = (&chunks[0], &chunks[1])
        else {
            panic!("Expected ToolUseStart followed by ToolUseInputDelta");
        };
        assert_eq!(name, "get_weather");
        assert_eq!(start_id, delta_id);
        assert_eq!(delta, r#"{"city":"NYC"}"#);

        let StreamChunk::MessageComplete(response) = &chunks[2] else {
            panic!("Expected MessageComplete");
        };
        assert_eq!(response.content.len(), 1);
        let AssistantContent::ToolCall {
            tool_call,
            thought_signature,
        } = &response.content[0]
        else {
            panic!("Expected ToolCall");
        };
        assert_eq!(tool_call.name, "get_weather");
        assert_eq!(&tool_call.id, start_id);
        assert_eq!(tool_call.parameters, json!({"city": "NYC"}));
        assert_eq!(
            thought_signature.as_ref().map(|sig| sig.as_str()),
            Some("sig_123")
        );
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_skips_unparseable_chunks() {
        let chunks = run_stream(
            &[
                "not json",
                r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#,
            ],
            CancellationToken::new(),
        )
        .await;

        assert_eq!(chunks.len(), 2);
        assert!(matches!(&chunks[0], StreamChunk::TextDelta(t) if t == "Hello"));
        assert!(matches!(&chunks[1], StreamChunk::MessageComplete(_)));
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_empty_stream_completes() {
        let chunks = run_stream(&[], CancellationToken::new()).await;

        assert_eq!(chunks.len(), 1);
        let StreamChunk::MessageComplete(response) = &chunks[0] else {
            panic!("Expected MessageComplete");
        };
        assert!(response.content.is_empty());
    }

    #[tokio::test]
    async fn test_convert_gemini_stream_cancellation() {
        let token = CancellationToken::new();
        token.cancel();

        let chunks = run_stream(
            &[r#"{"candidates":[{"content":{"role":"model","parts":[{"text":"Hello"}]}}]}"#],
            token,
        )
        .await;

        // An already-cancelled token must win over a ready event and end the stream.
        assert_eq!(chunks.len(), 1);
        assert!(matches!(
            &chunks[0],
            StreamChunk::Error(StreamError::Cancelled)
        ));
    }

    #[tokio::test]
    #[ignore = "Requires GOOGLE_API_KEY environment variable"]
    async fn test_stream_complete_real_api() {
        use crate::app::conversation::{Message, MessageData};

        dotenvy::dotenv().ok();
        let api_key = std::env::var("GOOGLE_API_KEY").expect("GOOGLE_API_KEY must be set");
        let client = GeminiClient::new(api_key);

        let message = Message {
            data: MessageData::User {
                content: vec![UserContent::Text {
                    text: "Say exactly: Hello".to_string(),
                }],
            },
            timestamp: chrono::Utc::now().timestamp_millis() as u64,
            id: "test-msg".to_string(),
            parent_message_id: None,
        };

        // The dated 04-17 preview model has been shut down; use the stable model id.
        let model_id = ModelId::new(crate::config::provider::google(), "gemini-2.5-flash");
        let token = CancellationToken::new();

        let mut stream = client
            .stream_complete(&model_id, vec![message], None, None, None, token)
            .await
            .expect("stream_complete should succeed");

        let mut got_text_delta = false;
        let mut got_message_complete = false;
        let mut accumulated_text = String::new();

        while let Some(chunk) = stream.next().await {
            match chunk {
                StreamChunk::TextDelta(text) => {
                    got_text_delta = true;
                    accumulated_text.push_str(&text);
                }
                StreamChunk::MessageComplete(response) => {
                    got_message_complete = true;
                    assert!(!response.content.is_empty());
                }
                StreamChunk::Error(e) => panic!("Unexpected error: {e:?}"),
                _ => {}
            }
        }

        assert!(got_text_delta, "Should receive at least one TextDelta");
        assert!(
            got_message_complete,
            "Should receive MessageComplete at the end"
        );
        assert!(
            accumulated_text.to_lowercase().contains("hello"),
            "Response should contain 'hello', got: {accumulated_text}"
        );
    }
}