1use std::sync::Arc;
2
3use async_trait::async_trait;
4use base64::Engine;
5use reqwest::{Client, Url};
6use serde::Deserialize;
7use serde_json::{Map, Value, json};
8use uuid::Uuid;
9
10use crate::messages::{
11 ModelMessage, ModelRequestPart, ModelResponse, ModelResponsePart, TextPart, ToolCallPart,
12 UserContent,
13};
14use crate::model::{Model, ModelError, ModelRequestParameters, ModelSettings, OutputMode};
15use crate::providers::{Provider, ProviderError};
16use crate::usage::RequestUsage;
17
18fn map_reqwest_error(label: &str, error: reqwest::Error) -> ModelError {
19 if error.is_timeout() {
20 return ModelError::Timeout;
21 }
22 if error.is_connect() {
23 return ModelError::Transport(format!("{label} connect error: {error}"));
24 }
25 ModelError::Transport(format!("{label} request failed: {error}"))
26}
27
/// Shortens an HTTP error body for logging.
///
/// Bodies of at most 512 bytes are returned unchanged. Longer bodies are cut at the last UTF-8
/// character boundary at or before byte 512 and suffixed with `... (<total> bytes)`.
fn truncate_error_body(body: &str) -> String {
    const LIMIT: usize = 512;
    if body.len() <= LIMIT {
        return body.to_string();
    }
    // Slicing at a fixed byte offset panics if it lands inside a multi-byte character (likely
    // with non-ASCII error messages), so back off to the nearest boundary. Index 0 is always a
    // boundary, so the fallback is never reached.
    let cut = (0..=LIMIT)
        .rev()
        .find(|&index| body.is_char_boundary(index))
        .unwrap_or(0);
    format!("{}... ({} bytes)", &body[..cut], body.len())
}
36
37fn normalize_tool_call_id(id: Option<String>) -> String {
38 match id {
39 Some(value) if !value.trim().is_empty() => value,
40 _ => format!("call_{}", Uuid::new_v4().simple()),
41 }
42}
43
44fn gemini_response_object(value: &Value) -> Value {
45 match value {
46 Value::Object(_) => value.clone(),
47 _ => {
48 let mut wrapped = Map::new();
49 wrapped.insert("return_value".to_string(), value.clone());
50 Value::Object(wrapped)
51 }
52 }
53}
54
55fn is_null_schema(value: &Value) -> bool {
56 matches!(
57 value,
58 Value::Object(map) if matches!(map.get("type"), Some(Value::String(t)) if t == "null")
59 )
60}
61
62fn sanitize_gemini_schema(value: &Value) -> Value {
63 match value {
64 Value::Object(map) => {
65 if let Some(variants) = map.get("anyOf").and_then(|val| val.as_array()) {
66 let mut cleaned = variants
67 .iter()
68 .filter(|variant| !is_null_schema(variant))
69 .map(sanitize_gemini_schema)
70 .collect::<Vec<_>>();
71 if cleaned.len() == 1 {
72 return cleaned.pop().unwrap_or(Value::Null);
73 }
74 }
75 if let Some(variants) = map.get("oneOf").and_then(|val| val.as_array()) {
76 let mut cleaned = variants
77 .iter()
78 .filter(|variant| !is_null_schema(variant))
79 .map(sanitize_gemini_schema)
80 .collect::<Vec<_>>();
81 if cleaned.len() == 1 {
82 return cleaned.pop().unwrap_or(Value::Null);
83 }
84 }
85
86 let mut out = Map::new();
87 for (key, val) in map {
88 if matches!(
89 key.as_str(),
90 "additionalProperties" | "$schema" | "$id" | "title"
91 ) {
92 continue;
93 }
94 if key == "type"
95 && let Value::Array(types) = val
96 {
97 if let Some(first) = types
98 .iter()
99 .find(|item| !matches!(item, Value::String(t) if t == "null"))
100 {
101 out.insert(key.clone(), first.clone());
102 }
103 continue;
104 }
105 out.insert(key.clone(), sanitize_gemini_schema(val));
106 }
107 Value::Object(out)
108 }
109 Value::Array(items) => Value::Array(items.iter().map(sanitize_gemini_schema).collect()),
110 _ => value.clone(),
111 }
112}
113
/// Guesses a MIME type from the file extension of `url`'s last path segment.
///
/// The query string and fragment are ignored, and for absolute URLs (`scheme://host/...`) the
/// host is never treated as a file name. Extension matching is case-insensitive. Returns
/// `None` when there is no extension or it is not in the table below.
fn infer_media_type_from_url(url: &str) -> Option<String> {
    // `split` always yields at least one item, so the fallback is never used.
    let without_suffix = url.split(['?', '#']).next().unwrap_or(url);
    // Skip scheme and authority so `https://cdn.png` is not read as a `.png` file.
    let path = match without_suffix.split_once("://") {
        Some((_, rest)) => rest.split_once('/').map_or("", |(_, path)| path),
        None => without_suffix,
    };
    let file_name = path.rsplit('/').next().unwrap_or(path);
    let (_, ext) = file_name.rsplit_once('.')?;
    let media_type = match ext.to_lowercase().as_str() {
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "webp" => "image/webp",
        "pdf" => "application/pdf",
        "txt" => "text/plain",
        "md" | "markdown" => "text/markdown",
        "csv" => "text/csv",
        "json" => "application/json",
        "mp3" => "audio/mpeg",
        "wav" => "audio/wav",
        "ogg" | "oga" => "audio/ogg",
        "flac" => "audio/flac",
        "m4a" | "aac" => "audio/aac",
        "mp4" => "video/mp4",
        "mov" => "video/quicktime",
        "webm" => "video/webm",
        "mkv" => "video/x-matroska",
        _ => return None,
    };
    Some(media_type.to_string())
}
140
141fn file_data_part(url: &str, media_type: &Option<String>) -> Value {
142 let mut file_data = Map::new();
143 file_data.insert("fileUri".to_string(), Value::String(url.to_string()));
144 let inferred = media_type
145 .clone()
146 .or_else(|| infer_media_type_from_url(url));
147 if let Some(media_type) = inferred {
148 file_data.insert("mimeType".to_string(), Value::String(media_type.clone()));
149 }
150 let mut wrapper = Map::new();
151 wrapper.insert("fileData".to_string(), Value::Object(file_data));
152 Value::Object(wrapper)
153}
154
155#[derive(Clone, Debug)]
156pub struct GeminiProvider {
157 api_key: String,
158 base_url: Url,
159}
160
161impl GeminiProvider {
162 pub fn new(
163 api_key: impl Into<String>,
164 base_url: impl AsRef<str>,
165 ) -> Result<Self, ProviderError> {
166 let url = Url::parse(base_url.as_ref())
167 .map_err(|_| ProviderError::InvalidModel(base_url.as_ref().to_string()))?;
168 Ok(Self {
169 api_key: api_key.into(),
170 base_url: url,
171 })
172 }
173
174 pub fn from_env() -> Result<Self, ProviderError> {
175 let api_key = std::env::var("GEMINI_API_KEY")
176 .or_else(|_| std::env::var("GOOGLE_API_KEY"))
177 .map_err(|_| ProviderError::MissingApiKey("gemini".to_string()))?;
178 Self::new(api_key, "https://generativelanguage.googleapis.com")
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::{
        BinaryContent, ImageUrl, ModelMessage, ModelRequest, ModelRequestPart, ModelResponse,
        ModelResponsePart, ToolCallPart, ToolReturnPart,
    };
    use base64::engine::general_purpose::STANDARD;
    use serde_json::{Value, json};
    use std::path::PathBuf;

    /// Reads `tests/fixtures/<name>` relative to the crate root, panicking if it is missing.
    fn fixture_bytes(name: &str) -> Vec<u8> {
        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests")
            .join("fixtures")
            .join(name);
        std::fs::read(path).expect("fixture read")
    }

    /// Binary content becomes base64 `inlineData`; URL content becomes `fileData` with a
    /// MIME type inferred from the extension when none is given.
    #[test]
    fn convert_user_content_handles_inline_and_file_data() {
        let pdf_bytes = fixture_bytes("fixture.pdf");
        let audio_bytes = fixture_bytes("fixture.m4a");

        let content = vec![
            UserContent::Binary(BinaryContent {
                data: pdf_bytes.clone(),
                media_type: "application/pdf".to_string(),
            }),
            UserContent::Binary(BinaryContent {
                data: audio_bytes.clone(),
                media_type: "audio/aac".to_string(),
            }),
            UserContent::Image(ImageUrl {
                url: "https://example.com/fixture.jpg".to_string(),
                media_type: None,
            }),
        ];

        let parts = convert_user_content(&content);
        assert_eq!(parts.len(), 3);

        let pdf = &parts[0];
        let pdf_inline = pdf.get("inlineData").expect("pdf inline");
        assert_eq!(
            pdf_inline.get("mimeType"),
            Some(&Value::String("application/pdf".to_string()))
        );
        assert_eq!(
            pdf_inline.get("data"),
            Some(&Value::String(STANDARD.encode(&pdf_bytes)))
        );

        let audio = &parts[1];
        let audio_inline = audio.get("inlineData").expect("audio inline");
        assert_eq!(
            audio_inline.get("mimeType"),
            Some(&Value::String("audio/aac".to_string()))
        );
        assert_eq!(
            audio_inline.get("data"),
            Some(&Value::String(STANDARD.encode(&audio_bytes)))
        );

        // No explicit media type: `image/jpeg` must come from the `.jpg` extension.
        let image = &parts[2];
        let file_data = image.get("fileData").expect("file data");
        assert_eq!(
            file_data.get("fileUri"),
            Some(&Value::String(
                "https://example.com/fixture.jpg".to_string()
            ))
        );
        assert_eq!(
            file_data.get("mimeType"),
            Some(&Value::String("image/jpeg".to_string()))
        );
    }

    /// A prior tool call is replayed as a `model` turn with `functionCall`, and its result as
    /// a `user` turn with `functionResponse` matched by tool name.
    #[test]
    fn split_system_replays_tool_calls() {
        let messages = vec![
            ModelMessage::Response(ModelResponse {
                parts: vec![ModelResponsePart::ToolCall(ToolCallPart {
                    id: "call-1".to_string(),
                    name: "get_data".to_string(),
                    arguments: json!({"a": 1}),
                })],
                usage: None,
                model_name: None,
                finish_reason: None,
            }),
            ModelMessage::Request(ModelRequest {
                parts: vec![ModelRequestPart::ToolReturn(ToolReturnPart {
                    tool_name: "get_data".to_string(),
                    tool_call_id: "call-1".to_string(),
                    content: json!({"ok": true}),
                })],
                instructions: None,
            }),
        ];

        let (_system, contents) = GeminiModel::split_system(&messages);
        assert_eq!(contents.len(), 2);

        let model_msg = contents[0].as_object().expect("model message");
        assert_eq!(
            model_msg.get("role"),
            Some(&Value::String("model".to_string()))
        );
        let model_parts = model_msg
            .get("parts")
            .and_then(|value| value.as_array())
            .expect("model parts");
        let function_call = model_parts
            .iter()
            .find_map(|part| part.get("functionCall"))
            .expect("functionCall");
        assert_eq!(
            function_call.get("name"),
            Some(&Value::String("get_data".to_string()))
        );
        assert_eq!(function_call.get("args"), Some(&json!({"a": 1})));

        let user_msg = contents[1].as_object().expect("user message");
        assert_eq!(
            user_msg.get("role"),
            Some(&Value::String("user".to_string()))
        );
        let user_parts = user_msg
            .get("parts")
            .and_then(|value| value.as_array())
            .expect("user parts");
        let function_response = user_parts
            .iter()
            .find_map(|part| part.get("functionResponse"))
            .expect("functionResponse");
        assert_eq!(
            function_response.get("name"),
            Some(&Value::String("get_data".to_string()))
        );
        // The tool result is already an object, so it is passed through unwrapped.
        assert_eq!(
            function_response.get("response"),
            Some(&json!({"ok": true}))
        );
    }

    /// Covers non-object wrapping, nullable-union collapsing with keyword stripping, and
    /// extension-based MIME inference (directly and via `file_data_part`).
    #[test]
    fn helper_functions_cover_schema_and_media() {
        let wrapped = gemini_response_object(&json!("ok"));
        assert_eq!(
            wrapped.get("return_value").and_then(|value| value.as_str()),
            Some("ok")
        );

        let schema = json!({
            "anyOf": [
                { "type": "null" },
                { "type": "string" }
            ],
            "title": "Example",
            "additionalProperties": false
        });
        let sanitized = sanitize_gemini_schema(&schema);
        assert_eq!(
            sanitized.get("type"),
            Some(&Value::String("string".to_string()))
        );
        assert!(sanitized.get("title").is_none());

        assert_eq!(
            infer_media_type_from_url("https://example.com/file.pdf"),
            Some("application/pdf".to_string())
        );
        assert_eq!(
            infer_media_type_from_url("https://example.com/file.unknown"),
            None
        );

        let part = file_data_part("https://example.com/file.txt", &None);
        let file_data = part.get("fileData").expect("file data");
        assert_eq!(
            file_data.get("mimeType"),
            Some(&Value::String("text/plain".to_string()))
        );
    }

    /// A blank tool-call id is replaced by a generated `call_` id; oversize error bodies are
    /// truncated with a byte-count suffix.
    #[test]
    fn helper_functions_cover_ids_and_truncation() {
        let id = normalize_tool_call_id(Some("".to_string()));
        assert!(id.starts_with("call_"));

        let truncated = truncate_error_body(&"a".repeat(600));
        assert!(truncated.contains("bytes"));
    }

    /// A `["null", "object"]` type array collapses to `"object"` and `$schema` is dropped.
    #[test]
    fn sanitize_gemini_schema_removes_null_type_array() {
        let schema = json!({
            "type": ["null", "object"],
            "properties": {"a": {"type": "string"}},
            "$schema": "http://json-schema.org/draft-07/schema#"
        });
        let sanitized = sanitize_gemini_schema(&schema);
        assert_eq!(
            sanitized.get("type"),
            Some(&Value::String("object".to_string()))
        );
        assert!(sanitized.get("$schema").is_none());
    }
}
392
393impl Provider for GeminiProvider {
394 fn name(&self) -> &str {
395 "gemini"
396 }
397
398 fn model(&self, model: &str, settings: Option<ModelSettings>) -> Arc<dyn Model> {
399 Arc::new(GeminiModel::new(
400 model,
401 self.api_key.clone(),
402 self.base_url.clone(),
403 settings,
404 ))
405 }
406}
407
408#[derive(Clone, Debug)]
409pub struct GeminiModel {
410 model: String,
411 api_key: String,
412 base_url: Url,
413 client: Client,
414 default_settings: Option<ModelSettings>,
415}
416
417impl GeminiModel {
418 pub fn new(
419 model: impl Into<String>,
420 api_key: String,
421 base_url: Url,
422 settings: Option<ModelSettings>,
423 ) -> Self {
424 let mut model = model.into();
425 if !model.starts_with("models/") {
426 model = format!("models/{model}");
427 }
428 Self {
429 model,
430 api_key,
431 base_url,
432 client: Client::new(),
433 default_settings: settings,
434 }
435 }
436
437 fn endpoint(&self) -> Result<Url, ModelError> {
438 let path = format!("v1beta/{}:generateContent", self.model);
439 let mut url = self
440 .base_url
441 .join(&path)
442 .map_err(|e| ModelError::Provider(format!("invalid base url: {e}")))?;
443 url.query_pairs_mut().append_pair("key", &self.api_key);
444 Ok(url)
445 }
446
447 fn split_system(messages: &[ModelMessage]) -> (Option<String>, Vec<Value>) {
448 let mut system_parts = Vec::new();
449 let mut contents = Vec::new();
450
451 for message in messages {
452 match message {
453 ModelMessage::Request(req) => {
454 if let Some(instructions) = req
455 .instructions
456 .as_ref()
457 .filter(|value| !value.trim().is_empty())
458 {
459 system_parts.push(instructions.to_string());
460 }
461 for part in &req.parts {
462 match part {
463 ModelRequestPart::SystemPrompt(prompt) => {
464 system_parts.push(prompt.content.clone());
465 }
466 ModelRequestPart::UserPrompt(prompt) => contents.push(json!({
467 "role": "user",
468 "parts": convert_user_content(&prompt.content)
469 })),
470 ModelRequestPart::ToolReturn(tool_return) => contents.push(json!({
471 "role": "user",
472 "parts": [{
473 "functionResponse": {
474 "name": tool_return.tool_name,
475 "response": gemini_response_object(&tool_return.content),
476 }
477 }]
478 })),
479 ModelRequestPart::RetryPrompt(retry) => {
480 let parts = if let Some(tool_name) = &retry.tool_name {
481 vec![json!({
482 "functionResponse": {
483 "name": tool_name,
484 "response": {"call_error": retry.content}
485 }
486 })]
487 } else {
488 vec![json!({"text": retry.content})]
489 };
490 contents.push(json!({
491 "role": "user",
492 "parts": parts
493 }));
494 }
495 }
496 }
497 }
498 ModelMessage::Response(res) => {
499 let mut parts = Vec::new();
500 if let Some(text) = res.text() {
501 parts.push(json!({"text": text}));
502 }
503 for call in res.tool_calls() {
504 parts.push(json!({
505 "functionCall": {
506 "name": call.name,
507 "args": call.arguments,
508 }
509 }));
510 }
511
512 if !parts.is_empty() {
513 contents.push(json!({
514 "role": "model",
515 "parts": parts
516 }));
517 }
518 }
519 }
520 }
521
522 let system = if system_parts.is_empty() {
523 None
524 } else {
525 Some(system_parts.join("\n\n"))
526 };
527
528 (system, contents)
529 }
530}
531
532fn convert_user_content(content: &[UserContent]) -> Vec<Value> {
533 let mut parts = Vec::new();
534 for item in content {
535 match item {
536 UserContent::Text(text) => parts.push(json!({"text": text})),
537 UserContent::Image(image) => parts.push(file_data_part(&image.url, &image.media_type)),
538 UserContent::Video(video) => parts.push(file_data_part(&video.url, &video.media_type)),
539 UserContent::Audio(audio) => parts.push(file_data_part(&audio.url, &audio.media_type)),
540 UserContent::Document(doc) => parts.push(file_data_part(&doc.url, &doc.media_type)),
541 UserContent::Binary(binary) => parts.push(json!({
542 "inlineData": {
543 "mimeType": binary.media_type,
544 "data": base64::engine::general_purpose::STANDARD.encode(&binary.data)
545 }
546 })),
547 }
548 }
549 parts
550}
551
552#[async_trait]
553impl Model for GeminiModel {
554 fn name(&self) -> &str {
555 &self.model
556 }
557
558 async fn request(
559 &self,
560 messages: &[ModelMessage],
561 settings: Option<&ModelSettings>,
562 params: &ModelRequestParameters,
563 ) -> Result<ModelResponse, ModelError> {
564 tracing::debug!(
565 model = %self.model,
566 tool_count = params.function_tools.len(),
567 output_schema = params.output_schema.is_some(),
568 "Gemini request"
569 );
570 let (system, contents) = Self::split_system(messages);
571 let mut body = Map::new();
572 body.insert("contents".to_string(), Value::Array(contents));
573 if let Some(system) = system {
574 body.insert(
575 "systemInstruction".to_string(),
576 json!({"parts": [{"text": system}]}),
577 );
578 }
579
580 if !params.function_tools.is_empty() {
581 let tools = params
582 .function_tools
583 .iter()
584 .map(|tool| {
585 let schema = sanitize_gemini_schema(&tool.parameters_json_schema);
586 json!({
587 "name": tool.name,
588 "description": tool.description,
589 "parameters": schema,
590 })
591 })
592 .collect::<Vec<_>>();
593 body.insert(
594 "tools".to_string(),
595 json!([{ "functionDeclarations": tools }]),
596 );
597 body.insert(
598 "toolConfig".to_string(),
599 json!({"functionCallingConfig": {"mode": "AUTO"}}),
600 );
601 }
602
603 if params.output_mode == OutputMode::JsonSchema
604 && let Some(schema) = params.output_schema.clone()
605 {
606 let schema = sanitize_gemini_schema(&schema);
607 body.insert(
608 "generationConfig".to_string(),
609 json!({
610 "responseMimeType": "application/json",
611 "responseSchema": schema
612 }),
613 );
614 }
615
616 if let Some(settings) = &self.default_settings {
617 for (key, value) in settings {
618 body.insert(key.clone(), value.clone());
619 }
620 }
621
622 if let Some(settings) = settings {
623 for (key, value) in settings {
624 body.insert(key.clone(), value.clone());
625 }
626 }
627
628 let response = self
629 .client
630 .post(self.endpoint()?)
631 .json(&Value::Object(body))
632 .send()
633 .await
634 .map_err(|e| map_reqwest_error("Gemini", e))?;
635
636 let status = response.status();
637 if !status.is_success() {
638 let body = response.text().await.unwrap_or_default();
639 tracing::error!(
640 status = status.as_u16(),
641 model = %self.model,
642 body = %truncate_error_body(&body),
643 "Gemini request failed"
644 );
645 return Err(ModelError::HttpStatus {
646 status: status.as_u16(),
647 });
648 }
649
650 let body: GeminiResponse = response.json().await.map_err(|e| {
651 tracing::error!(
652 error = %e,
653 model = %self.model,
654 "Gemini response parse failed"
655 );
656 ModelError::Provider(format!("Gemini response parse failed: {e}"))
657 })?;
658
659 let candidate = body.candidates.into_iter().next().ok_or_else(|| {
660 tracing::error!(model = %self.model, "Gemini response missing candidates");
661 ModelError::Provider("Gemini response missing candidates".to_string())
662 })?;
663
664 let mut parts = Vec::new();
665 if let Some(content) = candidate.content {
666 for part in content.parts {
667 if let Some(text) = part.text {
668 parts.push(ModelResponsePart::Text(TextPart { content: text }));
669 }
670 if let Some(call) = part.function_call {
671 parts.push(ModelResponsePart::ToolCall(ToolCallPart {
672 id: normalize_tool_call_id(call.id),
673 name: call.name.unwrap_or_else(|| "tool".to_string()),
674 arguments: call.args.unwrap_or_else(|| Value::Object(Map::new())),
675 }));
676 }
677 }
678 }
679
680 let usage = body.usage_metadata.map(|usage| RequestUsage {
681 input_tokens: usage.prompt_token_count.unwrap_or(0),
682 output_tokens: usage.candidates_token_count.unwrap_or(0),
683 ..Default::default()
684 });
685
686 Ok(ModelResponse {
687 parts,
688 usage,
689 model_name: Some(self.model.clone()),
690 finish_reason: candidate.finish_reason,
691 })
692 }
693}
694
695#[derive(Debug, Deserialize)]
696struct GeminiResponse {
697 candidates: Vec<GeminiCandidate>,
698 #[serde(rename = "usageMetadata")]
699 usage_metadata: Option<GeminiUsage>,
700}
701
702#[derive(Debug, Deserialize)]
703struct GeminiCandidate {
704 content: Option<GeminiContent>,
705 #[serde(rename = "finishReason")]
706 finish_reason: Option<String>,
707}
708
709#[derive(Debug, Deserialize)]
710struct GeminiContent {
711 parts: Vec<GeminiPart>,
712}
713
714#[derive(Debug, Deserialize)]
715struct GeminiPart {
716 text: Option<String>,
717 #[serde(rename = "functionCall")]
718 function_call: Option<GeminiFunctionCall>,
719}
720
721#[derive(Debug, Deserialize)]
722struct GeminiFunctionCall {
723 id: Option<String>,
724 name: Option<String>,
725 args: Option<Value>,
726}
727
728#[derive(Debug, Deserialize)]
729struct GeminiUsage {
730 #[serde(rename = "promptTokenCount")]
731 prompt_token_count: Option<u64>,
732 #[serde(rename = "candidatesTokenCount")]
733 candidates_token_count: Option<u64>,
734}