// rustic_ai/providers/gemini.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use base64::Engine;
5use reqwest::{Client, Url};
6use serde::Deserialize;
7use serde_json::{Map, Value, json};
8use uuid::Uuid;
9
10use crate::messages::{
11    ModelMessage, ModelRequestPart, ModelResponse, ModelResponsePart, TextPart, ToolCallPart,
12    UserContent,
13};
14use crate::model::{Model, ModelError, ModelRequestParameters, ModelSettings, OutputMode};
15use crate::providers::{Provider, ProviderError};
16use crate::usage::RequestUsage;
17
18fn map_reqwest_error(label: &str, error: reqwest::Error) -> ModelError {
19    if error.is_timeout() {
20        return ModelError::Timeout;
21    }
22    if error.is_connect() {
23        return ModelError::Transport(format!("{label} connect error: {error}"));
24    }
25    ModelError::Transport(format!("{label} request failed: {error}"))
26}
27
/// Shortens an HTTP error body for logging.
///
/// Bodies of at most 512 bytes are returned unchanged. Longer bodies are cut to at most 512
/// bytes and suffixed with `"... (<total> bytes)"`. The cut point is moved back to the
/// nearest UTF-8 character boundary, because slicing a `str` in the middle of a multi-byte
/// character panics and error bodies are arbitrary server-supplied text.
fn truncate_error_body(body: &str) -> String {
    const LIMIT: usize = 512;
    if body.len() <= LIMIT {
        return body.to_string();
    }
    let mut cut = LIMIT;
    // Index 0 is always a char boundary, so this loop terminates.
    while !body.is_char_boundary(cut) {
        cut -= 1;
    }
    format!("{}... ({} bytes)", &body[..cut], body.len())
}
36
37fn normalize_tool_call_id(id: Option<String>) -> String {
38    match id {
39        Some(value) if !value.trim().is_empty() => value,
40        _ => format!("call_{}", Uuid::new_v4().simple()),
41    }
42}
43
44fn gemini_response_object(value: &Value) -> Value {
45    match value {
46        Value::Object(_) => value.clone(),
47        _ => {
48            let mut wrapped = Map::new();
49            wrapped.insert("return_value".to_string(), value.clone());
50            Value::Object(wrapped)
51        }
52    }
53}
54
55fn is_null_schema(value: &Value) -> bool {
56    matches!(
57        value,
58        Value::Object(map) if matches!(map.get("type"), Some(Value::String(t)) if t == "null")
59    )
60}
61
62fn sanitize_gemini_schema(value: &Value) -> Value {
63    match value {
64        Value::Object(map) => {
65            if let Some(variants) = map.get("anyOf").and_then(|val| val.as_array()) {
66                let mut cleaned = variants
67                    .iter()
68                    .filter(|variant| !is_null_schema(variant))
69                    .map(sanitize_gemini_schema)
70                    .collect::<Vec<_>>();
71                if cleaned.len() == 1 {
72                    return cleaned.pop().unwrap_or(Value::Null);
73                }
74            }
75            if let Some(variants) = map.get("oneOf").and_then(|val| val.as_array()) {
76                let mut cleaned = variants
77                    .iter()
78                    .filter(|variant| !is_null_schema(variant))
79                    .map(sanitize_gemini_schema)
80                    .collect::<Vec<_>>();
81                if cleaned.len() == 1 {
82                    return cleaned.pop().unwrap_or(Value::Null);
83                }
84            }
85
86            let mut out = Map::new();
87            for (key, val) in map {
88                if matches!(
89                    key.as_str(),
90                    "additionalProperties" | "$schema" | "$id" | "title"
91                ) {
92                    continue;
93                }
94                if key == "type"
95                    && let Value::Array(types) = val
96                {
97                    if let Some(first) = types
98                        .iter()
99                        .find(|item| !matches!(item, Value::String(t) if t == "null"))
100                    {
101                        out.insert(key.clone(), first.clone());
102                    }
103                    continue;
104                }
105                out.insert(key.clone(), sanitize_gemini_schema(val));
106            }
107            Value::Object(out)
108        }
109        Value::Array(items) => Value::Array(items.iter().map(sanitize_gemini_schema).collect()),
110        _ => value.clone(),
111    }
112}
113
/// Guesses a MIME type from the file extension at the end of a URL's path.
///
/// The query string and `#fragment` are ignored, and so is the authority (host) of
/// `scheme://host/...` URLs. Only the last path segment is inspected, so
/// `https://cdn.example.com/clip.mp4#t=10` is recognised, while a host-only URL such as
/// `https://example.mov` is not. Extension matching is case-insensitive. Returns `None`
/// when there is no extension or it is not in the table.
fn infer_media_type_from_url(url: &str) -> Option<String> {
    let without_suffix = url.split(['?', '#']).next()?;
    let path = match without_suffix.split_once("://") {
        Some((_, rest)) => rest.find('/').map_or("", |idx| &rest[idx..]),
        None => without_suffix,
    };
    let file_name = path.rsplit('/').next()?;
    let (_, ext) = file_name.rsplit_once('.')?;
    let media_type = match ext.to_lowercase().as_str() {
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "webp" => "image/webp",
        "pdf" => "application/pdf",
        "txt" => "text/plain",
        "md" | "markdown" => "text/markdown",
        "csv" => "text/csv",
        "json" => "application/json",
        "mp3" => "audio/mpeg",
        "wav" => "audio/wav",
        "ogg" | "oga" => "audio/ogg",
        "flac" => "audio/flac",
        "m4a" | "aac" => "audio/aac",
        "mp4" => "video/mp4",
        "mov" => "video/quicktime",
        "webm" => "video/webm",
        "mkv" => "video/x-matroska",
        _ => return None,
    };
    Some(media_type.to_string())
}
140
141fn file_data_part(url: &str, media_type: &Option<String>) -> Value {
142    let mut file_data = Map::new();
143    file_data.insert("fileUri".to_string(), Value::String(url.to_string()));
144    let inferred = media_type
145        .clone()
146        .or_else(|| infer_media_type_from_url(url));
147    if let Some(media_type) = inferred {
148        file_data.insert("mimeType".to_string(), Value::String(media_type.clone()));
149    }
150    let mut wrapper = Map::new();
151    wrapper.insert("fileData".to_string(), Value::Object(file_data));
152    Value::Object(wrapper)
153}
154
155#[derive(Clone, Debug)]
156pub struct GeminiProvider {
157    api_key: String,
158    base_url: Url,
159}
160
161impl GeminiProvider {
162    pub fn new(
163        api_key: impl Into<String>,
164        base_url: impl AsRef<str>,
165    ) -> Result<Self, ProviderError> {
166        let url = Url::parse(base_url.as_ref())
167            .map_err(|_| ProviderError::InvalidModel(base_url.as_ref().to_string()))?;
168        Ok(Self {
169            api_key: api_key.into(),
170            base_url: url,
171        })
172    }
173
174    pub fn from_env() -> Result<Self, ProviderError> {
175        let api_key = std::env::var("GEMINI_API_KEY")
176            .or_else(|_| std::env::var("GOOGLE_API_KEY"))
177            .map_err(|_| ProviderError::MissingApiKey("gemini".to_string()))?;
178        Self::new(api_key, "https://generativelanguage.googleapis.com")
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::{
        BinaryContent, ImageUrl, ModelMessage, ModelRequest, ModelRequestPart, ModelResponse,
        ModelResponsePart, ToolCallPart, ToolReturnPart,
    };
    use base64::engine::general_purpose::STANDARD;
    use serde_json::{Value, json};
    use std::path::PathBuf;

    /// Reads `tests/fixtures/<name>` relative to the crate root, panicking if it is missing.
    fn fixture_bytes(name: &str) -> Vec<u8> {
        let path: PathBuf = [env!("CARGO_MANIFEST_DIR"), "tests", "fixtures", name]
            .into_iter()
            .collect();
        std::fs::read(path).expect("fixture read")
    }

    /// Returns the string stored under `key` in a JSON object, if there is one.
    fn str_at<'a>(value: &'a Value, key: &str) -> Option<&'a str> {
        value.get(key).and_then(Value::as_str)
    }

    #[test]
    fn convert_user_content_handles_inline_and_file_data() {
        let pdf_bytes = fixture_bytes("fixture.pdf");
        let audio_bytes = fixture_bytes("fixture.m4a");

        let content = vec![
            UserContent::Binary(BinaryContent {
                data: pdf_bytes.clone(),
                media_type: "application/pdf".to_string(),
            }),
            UserContent::Binary(BinaryContent {
                data: audio_bytes.clone(),
                media_type: "audio/aac".to_string(),
            }),
            UserContent::Image(ImageUrl {
                url: "https://example.com/fixture.jpg".to_string(),
                media_type: None,
            }),
        ];

        let parts = convert_user_content(&content);
        assert_eq!(parts.len(), 3);

        // Binary attachments are base64-encoded inline with their declared MIME type.
        let inline_expectations = [
            ("pdf inline", "application/pdf", &pdf_bytes),
            ("audio inline", "audio/aac", &audio_bytes),
        ];
        for (part, (label, mime, bytes)) in parts.iter().zip(inline_expectations) {
            let inline = part.get("inlineData").expect(label);
            assert_eq!(str_at(inline, "mimeType"), Some(mime));
            assert_eq!(
                str_at(inline, "data"),
                Some(STANDARD.encode(bytes).as_str())
            );
        }

        // A URL-only image becomes a fileData reference with the type inferred from `.jpg`.
        let file_data = parts[2].get("fileData").expect("file data");
        assert_eq!(
            str_at(file_data, "fileUri"),
            Some("https://example.com/fixture.jpg")
        );
        assert_eq!(str_at(file_data, "mimeType"), Some("image/jpeg"));
    }

    #[test]
    fn split_system_replays_tool_calls() {
        let tool_call = ModelResponse {
            parts: vec![ModelResponsePart::ToolCall(ToolCallPart {
                id: "call-1".to_string(),
                name: "get_data".to_string(),
                arguments: json!({"a": 1}),
            })],
            usage: None,
            model_name: None,
            finish_reason: None,
        };
        let tool_result = ModelRequest {
            parts: vec![ModelRequestPart::ToolReturn(ToolReturnPart {
                tool_name: "get_data".to_string(),
                tool_call_id: "call-1".to_string(),
                content: json!({"ok": true}),
            })],
            instructions: None,
        };
        let messages = [
            ModelMessage::Response(tool_call),
            ModelMessage::Request(tool_result),
        ];

        let (_system, contents) = GeminiModel::split_system(&messages);
        assert_eq!(contents.len(), 2);

        // The assistant turn replays the call as a functionCall part.
        let model_msg = contents[0].as_object().expect("model message");
        assert_eq!(model_msg.get("role").and_then(Value::as_str), Some("model"));
        let function_call = model_msg
            .get("parts")
            .and_then(Value::as_array)
            .expect("model parts")
            .iter()
            .find_map(|part| part.get("functionCall"))
            .expect("functionCall");
        assert_eq!(str_at(function_call, "name"), Some("get_data"));
        assert_eq!(function_call.get("args"), Some(&json!({"a": 1})));

        // The tool result comes back in a user turn as a functionResponse part.
        let user_msg = contents[1].as_object().expect("user message");
        assert_eq!(user_msg.get("role").and_then(Value::as_str), Some("user"));
        let function_response = user_msg
            .get("parts")
            .and_then(Value::as_array)
            .expect("user parts")
            .iter()
            .find_map(|part| part.get("functionResponse"))
            .expect("functionResponse");
        assert_eq!(str_at(function_response, "name"), Some("get_data"));
        assert_eq!(
            function_response.get("response"),
            Some(&json!({"ok": true}))
        );
    }

    #[test]
    fn helper_functions_cover_schema_and_media() {
        let wrapped = gemini_response_object(&json!("ok"));
        assert_eq!(str_at(&wrapped, "return_value"), Some("ok"));

        let sanitized = sanitize_gemini_schema(&json!({
            "anyOf": [
                { "type": "null" },
                { "type": "string" }
            ],
            "title": "Example",
            "additionalProperties": false
        }));
        assert_eq!(str_at(&sanitized, "type"), Some("string"));
        assert!(sanitized.get("title").is_none());

        let url_cases = [
            ("https://example.com/file.pdf", Some("application/pdf")),
            ("https://example.com/file.unknown", None),
        ];
        for (url, expected) in url_cases {
            assert_eq!(infer_media_type_from_url(url).as_deref(), expected);
        }

        let part = file_data_part("https://example.com/file.txt", &None);
        let file_data = part.get("fileData").expect("file data");
        assert_eq!(str_at(file_data, "mimeType"), Some("text/plain"));
    }

    #[test]
    fn helper_functions_cover_ids_and_truncation() {
        let generated = normalize_tool_call_id(Some(String::new()));
        assert!(generated.starts_with("call_"));

        let long_body = "a".repeat(600);
        assert!(truncate_error_body(&long_body).contains("bytes"));
    }

    #[test]
    fn sanitize_gemini_schema_removes_null_type_array() {
        let sanitized = sanitize_gemini_schema(&json!({
            "type": ["null", "object"],
            "properties": {"a": {"type": "string"}},
            "$schema": "http://json-schema.org/draft-07/schema#"
        }));
        assert_eq!(str_at(&sanitized, "type"), Some("object"));
        assert!(sanitized.get("$schema").is_none());
    }
}
392
393impl Provider for GeminiProvider {
394    fn name(&self) -> &str {
395        "gemini"
396    }
397
398    fn model(&self, model: &str, settings: Option<ModelSettings>) -> Arc<dyn Model> {
399        Arc::new(GeminiModel::new(
400            model,
401            self.api_key.clone(),
402            self.base_url.clone(),
403            settings,
404        ))
405    }
406}
407
408#[derive(Clone, Debug)]
409pub struct GeminiModel {
410    model: String,
411    api_key: String,
412    base_url: Url,
413    client: Client,
414    default_settings: Option<ModelSettings>,
415}
416
417impl GeminiModel {
418    pub fn new(
419        model: impl Into<String>,
420        api_key: String,
421        base_url: Url,
422        settings: Option<ModelSettings>,
423    ) -> Self {
424        let mut model = model.into();
425        if !model.starts_with("models/") {
426            model = format!("models/{model}");
427        }
428        Self {
429            model,
430            api_key,
431            base_url,
432            client: Client::new(),
433            default_settings: settings,
434        }
435    }
436
437    fn endpoint(&self) -> Result<Url, ModelError> {
438        let path = format!("v1beta/{}:generateContent", self.model);
439        let mut url = self
440            .base_url
441            .join(&path)
442            .map_err(|e| ModelError::Provider(format!("invalid base url: {e}")))?;
443        url.query_pairs_mut().append_pair("key", &self.api_key);
444        Ok(url)
445    }
446
447    fn split_system(messages: &[ModelMessage]) -> (Option<String>, Vec<Value>) {
448        let mut system_parts = Vec::new();
449        let mut contents = Vec::new();
450
451        for message in messages {
452            match message {
453                ModelMessage::Request(req) => {
454                    if let Some(instructions) = req
455                        .instructions
456                        .as_ref()
457                        .filter(|value| !value.trim().is_empty())
458                    {
459                        system_parts.push(instructions.to_string());
460                    }
461                    for part in &req.parts {
462                        match part {
463                            ModelRequestPart::SystemPrompt(prompt) => {
464                                system_parts.push(prompt.content.clone());
465                            }
466                            ModelRequestPart::UserPrompt(prompt) => contents.push(json!({
467                                "role": "user",
468                                "parts": convert_user_content(&prompt.content)
469                            })),
470                            ModelRequestPart::ToolReturn(tool_return) => contents.push(json!({
471                                "role": "user",
472                                "parts": [{
473                                    "functionResponse": {
474                                        "name": tool_return.tool_name,
475                                        "response": gemini_response_object(&tool_return.content),
476                                    }
477                                }]
478                            })),
479                            ModelRequestPart::RetryPrompt(retry) => {
480                                let parts = if let Some(tool_name) = &retry.tool_name {
481                                    vec![json!({
482                                        "functionResponse": {
483                                            "name": tool_name,
484                                            "response": {"call_error": retry.content}
485                                        }
486                                    })]
487                                } else {
488                                    vec![json!({"text": retry.content})]
489                                };
490                                contents.push(json!({
491                                    "role": "user",
492                                    "parts": parts
493                                }));
494                            }
495                        }
496                    }
497                }
498                ModelMessage::Response(res) => {
499                    let mut parts = Vec::new();
500                    if let Some(text) = res.text() {
501                        parts.push(json!({"text": text}));
502                    }
503                    for call in res.tool_calls() {
504                        parts.push(json!({
505                            "functionCall": {
506                                "name": call.name,
507                                "args": call.arguments,
508                            }
509                        }));
510                    }
511
512                    if !parts.is_empty() {
513                        contents.push(json!({
514                            "role": "model",
515                            "parts": parts
516                        }));
517                    }
518                }
519            }
520        }
521
522        let system = if system_parts.is_empty() {
523            None
524        } else {
525            Some(system_parts.join("\n\n"))
526        };
527
528        (system, contents)
529    }
530}
531
532fn convert_user_content(content: &[UserContent]) -> Vec<Value> {
533    let mut parts = Vec::new();
534    for item in content {
535        match item {
536            UserContent::Text(text) => parts.push(json!({"text": text})),
537            UserContent::Image(image) => parts.push(file_data_part(&image.url, &image.media_type)),
538            UserContent::Video(video) => parts.push(file_data_part(&video.url, &video.media_type)),
539            UserContent::Audio(audio) => parts.push(file_data_part(&audio.url, &audio.media_type)),
540            UserContent::Document(doc) => parts.push(file_data_part(&doc.url, &doc.media_type)),
541            UserContent::Binary(binary) => parts.push(json!({
542                "inlineData": {
543                    "mimeType": binary.media_type,
544                    "data": base64::engine::general_purpose::STANDARD.encode(&binary.data)
545                }
546            })),
547        }
548    }
549    parts
550}
551
552#[async_trait]
553impl Model for GeminiModel {
554    fn name(&self) -> &str {
555        &self.model
556    }
557
558    async fn request(
559        &self,
560        messages: &[ModelMessage],
561        settings: Option<&ModelSettings>,
562        params: &ModelRequestParameters,
563    ) -> Result<ModelResponse, ModelError> {
564        tracing::debug!(
565            model = %self.model,
566            tool_count = params.function_tools.len(),
567            output_schema = params.output_schema.is_some(),
568            "Gemini request"
569        );
570        let (system, contents) = Self::split_system(messages);
571        let mut body = Map::new();
572        body.insert("contents".to_string(), Value::Array(contents));
573        if let Some(system) = system {
574            body.insert(
575                "systemInstruction".to_string(),
576                json!({"parts": [{"text": system}]}),
577            );
578        }
579
580        if !params.function_tools.is_empty() {
581            let tools = params
582                .function_tools
583                .iter()
584                .map(|tool| {
585                    let schema = sanitize_gemini_schema(&tool.parameters_json_schema);
586                    json!({
587                        "name": tool.name,
588                        "description": tool.description,
589                        "parameters": schema,
590                    })
591                })
592                .collect::<Vec<_>>();
593            body.insert(
594                "tools".to_string(),
595                json!([{ "functionDeclarations": tools }]),
596            );
597            body.insert(
598                "toolConfig".to_string(),
599                json!({"functionCallingConfig": {"mode": "AUTO"}}),
600            );
601        }
602
603        if params.output_mode == OutputMode::JsonSchema
604            && let Some(schema) = params.output_schema.clone()
605        {
606            let schema = sanitize_gemini_schema(&schema);
607            body.insert(
608                "generationConfig".to_string(),
609                json!({
610                    "responseMimeType": "application/json",
611                    "responseSchema": schema
612                }),
613            );
614        }
615
616        if let Some(settings) = &self.default_settings {
617            for (key, value) in settings {
618                body.insert(key.clone(), value.clone());
619            }
620        }
621
622        if let Some(settings) = settings {
623            for (key, value) in settings {
624                body.insert(key.clone(), value.clone());
625            }
626        }
627
628        let response = self
629            .client
630            .post(self.endpoint()?)
631            .json(&Value::Object(body))
632            .send()
633            .await
634            .map_err(|e| map_reqwest_error("Gemini", e))?;
635
636        let status = response.status();
637        if !status.is_success() {
638            let body = response.text().await.unwrap_or_default();
639            tracing::error!(
640                status = status.as_u16(),
641                model = %self.model,
642                body = %truncate_error_body(&body),
643                "Gemini request failed"
644            );
645            return Err(ModelError::HttpStatus {
646                status: status.as_u16(),
647            });
648        }
649
650        let body: GeminiResponse = response.json().await.map_err(|e| {
651            tracing::error!(
652                error = %e,
653                model = %self.model,
654                "Gemini response parse failed"
655            );
656            ModelError::Provider(format!("Gemini response parse failed: {e}"))
657        })?;
658
659        let candidate = body.candidates.into_iter().next().ok_or_else(|| {
660            tracing::error!(model = %self.model, "Gemini response missing candidates");
661            ModelError::Provider("Gemini response missing candidates".to_string())
662        })?;
663
664        let mut parts = Vec::new();
665        if let Some(content) = candidate.content {
666            for part in content.parts {
667                if let Some(text) = part.text {
668                    parts.push(ModelResponsePart::Text(TextPart { content: text }));
669                }
670                if let Some(call) = part.function_call {
671                    parts.push(ModelResponsePart::ToolCall(ToolCallPart {
672                        id: normalize_tool_call_id(call.id),
673                        name: call.name.unwrap_or_else(|| "tool".to_string()),
674                        arguments: call.args.unwrap_or_else(|| Value::Object(Map::new())),
675                    }));
676                }
677            }
678        }
679
680        let usage = body.usage_metadata.map(|usage| RequestUsage {
681            input_tokens: usage.prompt_token_count.unwrap_or(0),
682            output_tokens: usage.candidates_token_count.unwrap_or(0),
683            ..Default::default()
684        });
685
686        Ok(ModelResponse {
687            parts,
688            usage,
689            model_name: Some(self.model.clone()),
690            finish_reason: candidate.finish_reason,
691        })
692    }
693}
694
695#[derive(Debug, Deserialize)]
696struct GeminiResponse {
697    candidates: Vec<GeminiCandidate>,
698    #[serde(rename = "usageMetadata")]
699    usage_metadata: Option<GeminiUsage>,
700}
701
702#[derive(Debug, Deserialize)]
703struct GeminiCandidate {
704    content: Option<GeminiContent>,
705    #[serde(rename = "finishReason")]
706    finish_reason: Option<String>,
707}
708
709#[derive(Debug, Deserialize)]
710struct GeminiContent {
711    parts: Vec<GeminiPart>,
712}
713
714#[derive(Debug, Deserialize)]
715struct GeminiPart {
716    text: Option<String>,
717    #[serde(rename = "functionCall")]
718    function_call: Option<GeminiFunctionCall>,
719}
720
721#[derive(Debug, Deserialize)]
722struct GeminiFunctionCall {
723    id: Option<String>,
724    name: Option<String>,
725    args: Option<Value>,
726}
727
728#[derive(Debug, Deserialize)]
729struct GeminiUsage {
730    #[serde(rename = "promptTokenCount")]
731    prompt_token_count: Option<u64>,
732    #[serde(rename = "candidatesTokenCount")]
733    candidates_token_count: Option<u64>,
734}