1use std::sync::Arc;
2
3use async_trait::async_trait;
4use base64::Engine;
5use reqwest::{Client, Url};
6use serde::Deserialize;
7use serde_json::{Map, Value, json};
8use uuid::Uuid;
9
10use crate::messages::{
11 ModelMessage, ModelRequestPart, ModelResponse, ModelResponsePart, TextPart, ToolCallPart,
12 UserContent,
13};
14use crate::model::{Model, ModelError, ModelRequestParameters, ModelSettings, OutputMode};
15use crate::providers::{Provider, ProviderError};
16use crate::usage::RequestUsage;
17
18fn map_reqwest_error(label: &str, error: reqwest::Error) -> ModelError {
19 if error.is_timeout() {
20 return ModelError::Timeout;
21 }
22 if error.is_connect() {
23 return ModelError::Transport(format!("{label} connect error: {error}"));
24 }
25 ModelError::Transport(format!("{label} request failed: {error}"))
26}
27
/// Shortens an HTTP error body for logging.
///
/// Bodies of at most 512 bytes are returned unchanged. Longer bodies are cut to at most
/// 512 bytes and suffixed with `... (<total> bytes)`. The cut point is moved back to the
/// nearest UTF-8 character boundary, so a multi-byte character straddling the limit never
/// causes a slicing panic.
fn truncate_error_body(body: &str) -> String {
    const LIMIT: usize = 512;
    if body.len() <= LIMIT {
        return body.to_string();
    }
    // `LIMIT` is a byte offset; index 0 is always a boundary, so the search always succeeds.
    let cut = (0..=LIMIT)
        .rev()
        .find(|&idx| body.is_char_boundary(idx))
        .unwrap_or(0);
    format!("{}... ({} bytes)", &body[..cut], body.len())
}
36
37fn normalize_tool_call_id(id: Option<String>) -> String {
38 match id {
39 Some(value) if !value.trim().is_empty() => value,
40 _ => format!("call_{}", Uuid::new_v4().simple()),
41 }
42}
43
44fn gemini_response_object(value: &Value) -> Value {
45 match value {
46 Value::Object(_) => value.clone(),
47 _ => {
48 let mut wrapped = Map::new();
49 wrapped.insert("return_value".to_string(), value.clone());
50 Value::Object(wrapped)
51 }
52 }
53}
54
55fn is_null_schema(value: &Value) -> bool {
56 matches!(
57 value,
58 Value::Object(map) if matches!(map.get("type"), Some(Value::String(t)) if t == "null")
59 )
60}
61
62fn sanitize_gemini_schema(value: &Value) -> Value {
63 match value {
64 Value::Object(map) => {
65 if let Some(variants) = map.get("anyOf").and_then(|val| val.as_array()) {
66 let mut cleaned = variants
67 .iter()
68 .filter(|variant| !is_null_schema(variant))
69 .map(sanitize_gemini_schema)
70 .collect::<Vec<_>>();
71 if cleaned.len() == 1 {
72 return cleaned.pop().unwrap_or(Value::Null);
73 }
74 }
75 if let Some(variants) = map.get("oneOf").and_then(|val| val.as_array()) {
76 let mut cleaned = variants
77 .iter()
78 .filter(|variant| !is_null_schema(variant))
79 .map(sanitize_gemini_schema)
80 .collect::<Vec<_>>();
81 if cleaned.len() == 1 {
82 return cleaned.pop().unwrap_or(Value::Null);
83 }
84 }
85
86 let mut out = Map::new();
87 for (key, val) in map {
88 if matches!(
89 key.as_str(),
90 "additionalProperties" | "$schema" | "$id" | "title"
91 ) {
92 continue;
93 }
94 if key == "type"
95 && let Value::Array(types) = val
96 {
97 if let Some(first) = types
98 .iter()
99 .find(|item| !matches!(item, Value::String(t) if t == "null"))
100 {
101 out.insert(key.clone(), first.clone());
102 }
103 continue;
104 }
105 out.insert(key.clone(), sanitize_gemini_schema(val));
106 }
107 Value::Object(out)
108 }
109 Value::Array(items) => Value::Array(items.iter().map(sanitize_gemini_schema).collect()),
110 _ => value.clone(),
111 }
112}
113
/// Guesses a MIME type from the file extension of `url`'s last path segment.
///
/// The query string (`?…`) and fragment (`#…`) are ignored, and matching is
/// case-insensitive. Returns `None` when the last segment has no extension or the
/// extension is not in the table below.
fn infer_media_type_from_url(url: &str) -> Option<String> {
    let path = url.split(['?', '#']).next()?;
    // Only the final segment can carry a file extension. This keeps a dotted host or
    // directory (`example.com/readme`, `v1.2/readme`) from being read as one.
    let file_name = path.rsplit('/').next()?;
    let (_, ext) = file_name.rsplit_once('.')?;
    let media_type = match ext.to_lowercase().as_str() {
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "gif" => "image/gif",
        "webp" => "image/webp",
        "pdf" => "application/pdf",
        "txt" => "text/plain",
        "md" | "markdown" => "text/markdown",
        "csv" => "text/csv",
        "json" => "application/json",
        "mp3" => "audio/mpeg",
        "wav" => "audio/wav",
        "ogg" | "oga" => "audio/ogg",
        "flac" => "audio/flac",
        "m4a" | "aac" => "audio/aac",
        "mp4" => "video/mp4",
        "mov" => "video/quicktime",
        "webm" => "video/webm",
        "mkv" => "video/x-matroska",
        _ => return None,
    };
    Some(media_type.to_string())
}
140
141fn file_data_part(url: &str, media_type: &Option<String>) -> Value {
142 let mut file_data = Map::new();
143 file_data.insert("fileUri".to_string(), Value::String(url.to_string()));
144 let inferred = media_type
145 .clone()
146 .or_else(|| infer_media_type_from_url(url));
147 if let Some(media_type) = inferred {
148 file_data.insert("mimeType".to_string(), Value::String(media_type.clone()));
149 }
150 let mut wrapper = Map::new();
151 wrapper.insert("fileData".to_string(), Value::Object(file_data));
152 Value::Object(wrapper)
153}
154
155#[derive(Clone, Debug)]
156pub struct GeminiProvider {
157 api_key: String,
158 base_url: Url,
159}
160
161impl GeminiProvider {
162 pub fn new(
163 api_key: impl Into<String>,
164 base_url: impl AsRef<str>,
165 ) -> Result<Self, ProviderError> {
166 let url = Url::parse(base_url.as_ref())
167 .map_err(|_| ProviderError::InvalidModel(base_url.as_ref().to_string()))?;
168 Ok(Self {
169 api_key: api_key.into(),
170 base_url: url,
171 })
172 }
173
174 pub fn from_env() -> Result<Self, ProviderError> {
175 let api_key = std::env::var("GEMINI_API_KEY")
176 .or_else(|_| std::env::var("GOOGLE_API_KEY"))
177 .map_err(|_| ProviderError::MissingApiKey("gemini".to_string()))?;
178 Self::new(api_key, "https://generativelanguage.googleapis.com")
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::{
        BinaryContent, ImageUrl, ModelMessage, ModelRequest, ModelRequestPart, ModelResponse,
        ModelResponsePart, ToolCallPart, ToolReturnPart,
    };
    use base64::engine::general_purpose::STANDARD;
    use serde_json::{Map, Value, json};
    use std::path::PathBuf;

    /// Reads `tests/fixtures/<name>` relative to the crate root.
    fn fixture_bytes(name: &str) -> Vec<u8> {
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.extend(["tests", "fixtures", name]);
        std::fs::read(path).expect("fixture read")
    }

    /// Returns the first `key` entry found among `message["parts"]`.
    ///
    /// Panics with `parts_label` if there is no `parts` array, and with `key` if no part
    /// contains that key.
    fn find_part<'a>(message: &'a Map<String, Value>, parts_label: &str, key: &str) -> &'a Value {
        message
            .get("parts")
            .and_then(Value::as_array)
            .expect(parts_label)
            .iter()
            .find_map(|part| part.get(key))
            .expect(key)
    }

    #[test]
    fn convert_user_content_handles_inline_and_file_data() {
        let pdf_bytes = fixture_bytes("fixture.pdf");
        let audio_bytes = fixture_bytes("fixture.m4a");
        let image_url = "https://example.com/fixture.jpg";

        let content = vec![
            UserContent::Binary(BinaryContent {
                data: pdf_bytes.clone(),
                media_type: "application/pdf".to_string(),
            }),
            UserContent::Binary(BinaryContent {
                data: audio_bytes.clone(),
                media_type: "audio/aac".to_string(),
            }),
            UserContent::Image(ImageUrl {
                url: image_url.to_string(),
                media_type: None,
            }),
        ];

        let parts = convert_user_content(&content);
        assert_eq!(parts.len(), 3);

        // Raw bytes travel inline as base64 together with their declared MIME type.
        let expected_inline = [
            ("pdf inline", "application/pdf", &pdf_bytes),
            ("audio inline", "audio/aac", &audio_bytes),
        ];
        for (part, (label, mime, bytes)) in parts.iter().zip(expected_inline) {
            let inline = part.get("inlineData").expect(label);
            assert_eq!(inline.get("mimeType"), Some(&json!(mime)));
            assert_eq!(inline.get("data"), Some(&json!(STANDARD.encode(bytes))));
        }

        // URL content becomes a fileData reference with a MIME type inferred from the URL.
        let file_data = parts[2].get("fileData").expect("file data");
        assert_eq!(file_data.get("fileUri"), Some(&json!(image_url)));
        assert_eq!(file_data.get("mimeType"), Some(&json!("image/jpeg")));
    }

    #[test]
    fn split_system_replays_tool_calls() {
        let call = ToolCallPart {
            id: "call-1".to_string(),
            name: "get_data".to_string(),
            arguments: json!({"a": 1}),
        };
        let tool_return = ToolReturnPart {
            tool_name: "get_data".to_string(),
            tool_call_id: "call-1".to_string(),
            content: json!({"ok": true}),
        };
        let messages = vec![
            ModelMessage::Response(ModelResponse {
                parts: vec![ModelResponsePart::ToolCall(call)],
                usage: None,
                model_name: None,
                finish_reason: None,
            }),
            ModelMessage::Request(ModelRequest {
                parts: vec![ModelRequestPart::ToolReturn(tool_return)],
                instructions: None,
            }),
        ];

        let (_system, contents) = GeminiModel::split_system(&messages);
        assert_eq!(contents.len(), 2);

        // Turn 1: the model's tool call is replayed as a `functionCall` part.
        let model_msg = contents[0].as_object().expect("model message");
        assert_eq!(model_msg.get("role"), Some(&json!("model")));
        let function_call = find_part(model_msg, "model parts", "functionCall");
        assert_eq!(function_call.get("name"), Some(&json!("get_data")));
        assert_eq!(function_call.get("args"), Some(&json!({"a": 1})));

        // Turn 2: the tool result goes back as a user-side `functionResponse` part.
        let user_msg = contents[1].as_object().expect("user message");
        assert_eq!(user_msg.get("role"), Some(&json!("user")));
        let function_response = find_part(user_msg, "user parts", "functionResponse");
        assert_eq!(function_response.get("name"), Some(&json!("get_data")));
        assert_eq!(function_response.get("response"), Some(&json!({"ok": true})));
    }
}
328
329impl Provider for GeminiProvider {
330 fn name(&self) -> &str {
331 "gemini"
332 }
333
334 fn model(&self, model: &str, settings: Option<ModelSettings>) -> Arc<dyn Model> {
335 Arc::new(GeminiModel::new(
336 model,
337 self.api_key.clone(),
338 self.base_url.clone(),
339 settings,
340 ))
341 }
342}
343
344#[derive(Clone, Debug)]
345pub struct GeminiModel {
346 model: String,
347 api_key: String,
348 base_url: Url,
349 client: Client,
350 default_settings: Option<ModelSettings>,
351}
352
353impl GeminiModel {
354 pub fn new(
355 model: impl Into<String>,
356 api_key: String,
357 base_url: Url,
358 settings: Option<ModelSettings>,
359 ) -> Self {
360 let mut model = model.into();
361 if !model.starts_with("models/") {
362 model = format!("models/{model}");
363 }
364 Self {
365 model,
366 api_key,
367 base_url,
368 client: Client::new(),
369 default_settings: settings,
370 }
371 }
372
373 fn endpoint(&self) -> Result<Url, ModelError> {
374 let path = format!("v1beta/{}:generateContent", self.model);
375 let mut url = self
376 .base_url
377 .join(&path)
378 .map_err(|e| ModelError::Provider(format!("invalid base url: {e}")))?;
379 url.query_pairs_mut().append_pair("key", &self.api_key);
380 Ok(url)
381 }
382
383 fn split_system(messages: &[ModelMessage]) -> (Option<String>, Vec<Value>) {
384 let mut system_parts = Vec::new();
385 let mut contents = Vec::new();
386
387 for message in messages {
388 match message {
389 ModelMessage::Request(req) => {
390 if let Some(instructions) = req
391 .instructions
392 .as_ref()
393 .filter(|value| !value.trim().is_empty())
394 {
395 system_parts.push(instructions.to_string());
396 }
397 for part in &req.parts {
398 match part {
399 ModelRequestPart::SystemPrompt(prompt) => {
400 system_parts.push(prompt.content.clone());
401 }
402 ModelRequestPart::UserPrompt(prompt) => contents.push(json!({
403 "role": "user",
404 "parts": convert_user_content(&prompt.content)
405 })),
406 ModelRequestPart::ToolReturn(tool_return) => contents.push(json!({
407 "role": "user",
408 "parts": [{
409 "functionResponse": {
410 "name": tool_return.tool_name,
411 "response": gemini_response_object(&tool_return.content),
412 }
413 }]
414 })),
415 ModelRequestPart::RetryPrompt(retry) => {
416 let parts = if let Some(tool_name) = &retry.tool_name {
417 vec![json!({
418 "functionResponse": {
419 "name": tool_name,
420 "response": {"call_error": retry.content}
421 }
422 })]
423 } else {
424 vec![json!({"text": retry.content})]
425 };
426 contents.push(json!({
427 "role": "user",
428 "parts": parts
429 }));
430 }
431 }
432 }
433 }
434 ModelMessage::Response(res) => {
435 let mut parts = Vec::new();
436 if let Some(text) = res.text() {
437 parts.push(json!({"text": text}));
438 }
439 for call in res.tool_calls() {
440 parts.push(json!({
441 "functionCall": {
442 "name": call.name,
443 "args": call.arguments,
444 }
445 }));
446 }
447
448 if !parts.is_empty() {
449 contents.push(json!({
450 "role": "model",
451 "parts": parts
452 }));
453 }
454 }
455 }
456 }
457
458 let system = if system_parts.is_empty() {
459 None
460 } else {
461 Some(system_parts.join("\n\n"))
462 };
463
464 (system, contents)
465 }
466}
467
468fn convert_user_content(content: &[UserContent]) -> Vec<Value> {
469 let mut parts = Vec::new();
470 for item in content {
471 match item {
472 UserContent::Text(text) => parts.push(json!({"text": text})),
473 UserContent::Image(image) => parts.push(file_data_part(&image.url, &image.media_type)),
474 UserContent::Video(video) => parts.push(file_data_part(&video.url, &video.media_type)),
475 UserContent::Audio(audio) => parts.push(file_data_part(&audio.url, &audio.media_type)),
476 UserContent::Document(doc) => parts.push(file_data_part(&doc.url, &doc.media_type)),
477 UserContent::Binary(binary) => parts.push(json!({
478 "inlineData": {
479 "mimeType": binary.media_type,
480 "data": base64::engine::general_purpose::STANDARD.encode(&binary.data)
481 }
482 })),
483 }
484 }
485 parts
486}
487
488#[async_trait]
489impl Model for GeminiModel {
490 fn name(&self) -> &str {
491 &self.model
492 }
493
494 async fn request(
495 &self,
496 messages: &[ModelMessage],
497 settings: Option<&ModelSettings>,
498 params: &ModelRequestParameters,
499 ) -> Result<ModelResponse, ModelError> {
500 tracing::debug!(
501 model = %self.model,
502 tool_count = params.function_tools.len(),
503 output_schema = params.output_schema.is_some(),
504 "Gemini request"
505 );
506 let (system, contents) = Self::split_system(messages);
507 let mut body = Map::new();
508 body.insert("contents".to_string(), Value::Array(contents));
509 if let Some(system) = system {
510 body.insert(
511 "systemInstruction".to_string(),
512 json!({"parts": [{"text": system}]}),
513 );
514 }
515
516 if !params.function_tools.is_empty() {
517 let tools = params
518 .function_tools
519 .iter()
520 .map(|tool| {
521 let schema = sanitize_gemini_schema(&tool.parameters_json_schema);
522 json!({
523 "name": tool.name,
524 "description": tool.description,
525 "parameters": schema,
526 })
527 })
528 .collect::<Vec<_>>();
529 body.insert(
530 "tools".to_string(),
531 json!([{ "functionDeclarations": tools }]),
532 );
533 body.insert(
534 "toolConfig".to_string(),
535 json!({"functionCallingConfig": {"mode": "AUTO"}}),
536 );
537 }
538
539 if params.output_mode == OutputMode::JsonSchema
540 && let Some(schema) = params.output_schema.clone()
541 {
542 let schema = sanitize_gemini_schema(&schema);
543 body.insert(
544 "generationConfig".to_string(),
545 json!({
546 "responseMimeType": "application/json",
547 "responseSchema": schema
548 }),
549 );
550 }
551
552 if let Some(settings) = &self.default_settings {
553 for (key, value) in settings {
554 body.insert(key.clone(), value.clone());
555 }
556 }
557
558 if let Some(settings) = settings {
559 for (key, value) in settings {
560 body.insert(key.clone(), value.clone());
561 }
562 }
563
564 let response = self
565 .client
566 .post(self.endpoint()?)
567 .json(&Value::Object(body))
568 .send()
569 .await
570 .map_err(|e| map_reqwest_error("Gemini", e))?;
571
572 let status = response.status();
573 if !status.is_success() {
574 let body = response.text().await.unwrap_or_default();
575 tracing::error!(
576 status = status.as_u16(),
577 model = %self.model,
578 body = %truncate_error_body(&body),
579 "Gemini request failed"
580 );
581 return Err(ModelError::HttpStatus {
582 status: status.as_u16(),
583 });
584 }
585
586 let body: GeminiResponse = response.json().await.map_err(|e| {
587 tracing::error!(
588 error = %e,
589 model = %self.model,
590 "Gemini response parse failed"
591 );
592 ModelError::Provider(format!("Gemini response parse failed: {e}"))
593 })?;
594
595 let candidate = body.candidates.into_iter().next().ok_or_else(|| {
596 tracing::error!(model = %self.model, "Gemini response missing candidates");
597 ModelError::Provider("Gemini response missing candidates".to_string())
598 })?;
599
600 let mut parts = Vec::new();
601 if let Some(content) = candidate.content {
602 for part in content.parts {
603 if let Some(text) = part.text {
604 parts.push(ModelResponsePart::Text(TextPart { content: text }));
605 }
606 if let Some(call) = part.function_call {
607 parts.push(ModelResponsePart::ToolCall(ToolCallPart {
608 id: normalize_tool_call_id(call.id),
609 name: call.name.unwrap_or_else(|| "tool".to_string()),
610 arguments: call.args.unwrap_or_else(|| Value::Object(Map::new())),
611 }));
612 }
613 }
614 }
615
616 let usage = body.usage_metadata.map(|usage| RequestUsage {
617 input_tokens: usage.prompt_token_count.unwrap_or(0),
618 output_tokens: usage.candidates_token_count.unwrap_or(0),
619 ..Default::default()
620 });
621
622 Ok(ModelResponse {
623 parts,
624 usage,
625 model_name: Some(self.model.clone()),
626 finish_reason: candidate.finish_reason,
627 })
628 }
629}
630
631#[derive(Debug, Deserialize)]
632struct GeminiResponse {
633 candidates: Vec<GeminiCandidate>,
634 #[serde(rename = "usageMetadata")]
635 usage_metadata: Option<GeminiUsage>,
636}
637
638#[derive(Debug, Deserialize)]
639struct GeminiCandidate {
640 content: Option<GeminiContent>,
641 #[serde(rename = "finishReason")]
642 finish_reason: Option<String>,
643}
644
645#[derive(Debug, Deserialize)]
646struct GeminiContent {
647 parts: Vec<GeminiPart>,
648}
649
650#[derive(Debug, Deserialize)]
651struct GeminiPart {
652 text: Option<String>,
653 #[serde(rename = "functionCall")]
654 function_call: Option<GeminiFunctionCall>,
655}
656
657#[derive(Debug, Deserialize)]
658struct GeminiFunctionCall {
659 id: Option<String>,
660 name: Option<String>,
661 args: Option<Value>,
662}
663
664#[derive(Debug, Deserialize)]
665struct GeminiUsage {
666 #[serde(rename = "promptTokenCount")]
667 prompt_token_count: Option<u64>,
668 #[serde(rename = "candidatesTokenCount")]
669 candidates_token_count: Option<u64>,
670}