strands_agents/models/
gemini.rs

1//! Google Gemini model provider.
2//!
3//! Docs: https://ai.google.dev/api
4
5use std::collections::HashMap;
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9
10use crate::models::{Model, ModelConfig, StreamEventStream};
11use crate::types::content::{Message, SystemContentBlock};
12use crate::types::errors::StrandsError;
13use crate::types::streaming::StreamEvent;
14use crate::types::tools::{ToolChoice, ToolSpec};
15
/// Configuration for Gemini models.
#[derive(Debug, Clone, Default)]
pub struct GeminiConfig {
    /// Gemini model ID (e.g., "gemini-2.5-flash").
    pub model_id: String,
    /// Additional model parameters (e.g., temperature); serialized into the
    /// request's `generationConfig` when non-empty.
    pub params: HashMap<String, serde_json::Value>,
    /// API key for authentication. When `None`, the `GOOGLE_API_KEY`
    /// environment variable is consulted at request time.
    pub api_key: Option<String>,
    /// Base URL for the API. When `None`, the public Google endpoint is used.
    pub base_url: Option<String>,
}
28
29impl GeminiConfig {
30    pub fn new(model_id: impl Into<String>) -> Self {
31        Self {
32            model_id: model_id.into(),
33            ..Default::default()
34        }
35    }
36
37    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
38        self.api_key = Some(api_key.into());
39        self
40    }
41
42    pub fn with_param(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
43        self.params.insert(key.into(), value);
44        self
45    }
46}
47
/// Gemini API request format.
///
/// Serialized as camelCase to match the REST API
/// (e.g. `systemInstruction`, `generationConfig`).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiRequest {
    /// Conversation turns in request order.
    contents: Vec<GeminiContent>,
    /// Optional system prompt, sent as `systemInstruction`.
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GeminiContent>,
    /// Tool (function) declarations available to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GeminiTool>>,
    /// Free-form generation parameters (temperature etc.), sent as `generationConfig`.
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<serde_json::Value>,
}
60
/// A single conversation turn: a role plus one or more parts.
#[derive(Debug, Serialize, Deserialize)]
struct GeminiContent {
    /// Either "user" or "model" (Gemini's name for the assistant role).
    role: String,
    parts: Vec<GeminiPart>,
}
66
67#[derive(Debug, Serialize, Deserialize)]
68#[serde(untagged)]
69enum GeminiPart {
70    Text { text: String },
71    FunctionCall { function_call: GeminiFunctionCall },
72    FunctionResponse { function_response: GeminiFunctionResponse },
73}
74
/// A function call emitted by the model.
#[derive(Debug, Serialize, Deserialize)]
struct GeminiFunctionCall {
    /// Declared function name.
    name: String,
    /// JSON object of call arguments.
    args: serde_json::Value,
}
80
/// The result of a function call, echoed back to the model.
#[derive(Debug, Serialize, Deserialize)]
struct GeminiFunctionResponse {
    /// Function name the response corresponds to.
    name: String,
    /// Arbitrary JSON payload carrying the tool output.
    response: serde_json::Value,
}
86
87#[derive(Debug, Serialize)]
88struct GeminiTool {
89    function_declarations: Vec<GeminiFunctionDeclaration>,
90}
91
/// A single function declaration exposed to the model.
#[derive(Debug, Serialize)]
struct GeminiFunctionDeclaration {
    name: String,
    description: String,
    /// JSON Schema describing the function's parameters.
    parameters: serde_json::Value,
}
98
/// Google Gemini model provider.
pub struct GeminiModel {
    /// Generic model configuration shared across providers.
    config: ModelConfig,
    /// Gemini-specific settings (API key, base URL, extra params).
    gemini_config: GeminiConfig,
    /// Reused HTTP client (connection pooling across requests).
    client: reqwest::Client,
}
105
106impl GeminiModel {
107    const DEFAULT_BASE_URL: &'static str = "https://generativelanguage.googleapis.com/v1beta";
108
109    pub fn new(config: GeminiConfig) -> Self {
110        let model_config = ModelConfig::new(&config.model_id);
111
112        Self {
113            config: model_config,
114            gemini_config: config,
115            client: reqwest::Client::new(),
116        }
117    }
118
119    fn base_url(&self) -> &str {
120        self.gemini_config
121            .base_url
122            .as_deref()
123            .unwrap_or(Self::DEFAULT_BASE_URL)
124    }
125
126    fn api_key(&self) -> Result<&str, StrandsError> {
127        self.gemini_config
128            .api_key
129            .as_deref()
130            .or_else(|| std::env::var("GOOGLE_API_KEY").ok().as_deref().map(|_| ""))
131            .ok_or_else(|| StrandsError::ConfigurationError {
132                message: "Gemini API key not configured. Set GOOGLE_API_KEY or provide api_key".into(),
133            })
134    }
135
136    fn convert_messages(&self, messages: &[Message]) -> Vec<GeminiContent> {
137        messages
138            .iter()
139            .map(|msg| {
140                let role = match msg.role {
141                    crate::types::content::Role::User => "user",
142                    crate::types::content::Role::Assistant => "model",
143                };
144
145                let parts: Vec<GeminiPart> = msg
146                    .content
147                    .iter()
148                    .filter_map(|block| {
149                        if let Some(text) = &block.text {
150                            Some(GeminiPart::Text { text: text.clone() })
151                        } else if let Some(tool_use) = &block.tool_use {
152                            Some(GeminiPart::FunctionCall {
153                                function_call: GeminiFunctionCall {
154                                    name: tool_use.name.clone(),
155                                    args: tool_use.input.clone(),
156                                },
157                            })
158                        } else if let Some(tool_result) = &block.tool_result {
159                            Some(GeminiPart::FunctionResponse {
160                                function_response: GeminiFunctionResponse {
161                                    name: tool_result.tool_use_id.clone(),
162                                    response: serde_json::json!({
163                                        "content": tool_result.content
164                                    }),
165                                },
166                            })
167                        } else {
168                            None
169                        }
170                    })
171                    .collect();
172
173                GeminiContent {
174                    role: role.to_string(),
175                    parts,
176                }
177            })
178            .collect()
179    }
180
181    fn convert_tools(&self, tool_specs: &[ToolSpec]) -> Vec<GeminiTool> {
182        let declarations: Vec<GeminiFunctionDeclaration> = tool_specs
183            .iter()
184            .map(|spec| GeminiFunctionDeclaration {
185                name: spec.name.clone(),
186                description: spec.description.clone(),
187                parameters: serde_json::to_value(&spec.input_schema).unwrap_or_default(),
188            })
189            .collect();
190
191        vec![GeminiTool {
192            function_declarations: declarations,
193        }]
194    }
195}
196
197#[async_trait]
198impl Model for GeminiModel {
199    fn config(&self) -> &ModelConfig {
200        &self.config
201    }
202
203    fn update_config(&mut self, config: ModelConfig) {
204        self.config = config;
205    }
206
207    fn stream<'a>(
208        &'a self,
209        messages: &'a [Message],
210        tool_specs: Option<&'a [ToolSpec]>,
211        system_prompt: Option<&'a str>,
212        _tool_choice: Option<ToolChoice>,
213        _system_prompt_content: Option<&'a [SystemContentBlock]>,
214    ) -> StreamEventStream<'a> {
215        let messages = messages.to_vec();
216        let tool_specs = tool_specs.map(|t| t.to_vec());
217        let system_prompt = system_prompt.map(|s| s.to_string());
218
219        Box::pin(async_stream::stream! {
220            let api_key = match self.api_key() {
221                Ok(key) => key.to_string(),
222                Err(e) => {
223                    yield Err(e);
224                    return;
225                }
226            };
227
228            let api_key = if api_key.is_empty() {
229                match std::env::var("GOOGLE_API_KEY") {
230                    Ok(key) => key,
231                    Err(_) => {
232                        yield Err(StrandsError::ConfigurationError {
233                            message: "GOOGLE_API_KEY not set".into(),
234                        });
235                        return;
236                    }
237                }
238            } else {
239                api_key
240            };
241
242            let contents = self.convert_messages(&messages);
243
244            let system_instruction = system_prompt.map(|prompt| GeminiContent {
245                role: "user".to_string(),
246                parts: vec![GeminiPart::Text { text: prompt }],
247            });
248
249            let tools = tool_specs.as_ref().map(|specs| self.convert_tools(specs));
250
251            let request = GeminiRequest {
252                contents,
253                system_instruction,
254                tools,
255                generation_config: if self.gemini_config.params.is_empty() {
256                    None
257                } else {
258                    Some(serde_json::to_value(&self.gemini_config.params).unwrap_or_default())
259                },
260            };
261
262            let url = format!(
263                "{}/models/{}:streamGenerateContent?key={}&alt=sse",
264                self.base_url(),
265                self.config.model_id,
266                api_key
267            );
268
269            let response = match self.client
270                .post(&url)
271                .json(&request)
272                .send()
273                .await
274            {
275                Ok(resp) => resp,
276                Err(e) => {
277                    yield Err(StrandsError::NetworkError(e.to_string()));
278                    return;
279                }
280            };
281
282            if !response.status().is_success() {
283                let status = response.status();
284                let body = response.text().await.unwrap_or_default();
285
286                if status.as_u16() == 429 {
287                    yield Err(StrandsError::ModelThrottled {
288                        message: "Gemini rate limit exceeded".into(),
289                    });
290                } else {
291                    yield Err(StrandsError::ModelError {
292                        message: format!("Gemini API error {}: {}", status, body),
293                        source: None,
294                    });
295                }
296                return;
297            }
298
299            yield Ok(StreamEvent::message_start(crate::types::content::Role::Assistant));
300            yield Ok(StreamEvent::content_block_start(0, None));
301
302            let body = match response.text().await {
303                Ok(b) => b,
304                Err(e) => {
305                    yield Err(StrandsError::NetworkError(e.to_string()));
306                    return;
307                }
308            };
309
310            let mut tool_used = false;
311            let mut finish_reason = "STOP";
312            let mut input_tokens = 0u64;
313            let mut output_tokens = 0u64;
314
315            for line in body.lines() {
316                let line = line.trim();
317
318                if line.is_empty() || line.starts_with(':') {
319                    continue;
320                }
321
322                if let Some(data) = line.strip_prefix("data: ") {
323                    if data.trim() == "[DONE]" {
324                        continue;
325                    }
326
327                    if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(data) {
328                        if let Some(usage) = parsed.get("usageMetadata") {
329                            if let Some(prompt_tokens) = usage.get("promptTokenCount").and_then(|v| v.as_u64()) {
330                                input_tokens = prompt_tokens;
331                            }
332                            if let Some(candidates_tokens) = usage.get("candidatesTokenCount").and_then(|v| v.as_u64()) {
333                                output_tokens = candidates_tokens;
334                            }
335                        }
336
337                        if let Some(candidates) = parsed.get("candidates").and_then(|c| c.as_array()) {
338                            for candidate in candidates {
339                                if let Some(reason) = candidate.get("finishReason").and_then(|r| r.as_str()) {
340                                    finish_reason = match reason {
341                                        "MAX_TOKENS" => "MAX_TOKENS",
342                                        "SAFETY" => "SAFETY",
343                                        "STOP" | _ => "STOP",
344                                    };
345                                }
346
347                                if let Some(content) = candidate.get("content") {
348                                    if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) {
349                                        for part in parts {
350                                            if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
351                                                let is_thought = part.get("thought").and_then(|t| t.as_bool()).unwrap_or(false);
352                                                if is_thought {
353                                                    yield Ok(StreamEvent::reasoning_delta(0, text));
354                                                } else {
355                                                    yield Ok(StreamEvent::text_delta(0, text));
356                                                }
357                                            }
358
359                                            if let Some(function_call) = part.get("functionCall") {
360                                                if let (Some(name), Some(args)) = (
361                                                    function_call.get("name").and_then(|n| n.as_str()),
362                                                    function_call.get("args"),
363                                                ) {
364                                                    tool_used = true;
365                                                    yield Ok(StreamEvent::tool_use_start(
366                                                        1,
367                                                        name,
368                                                        name,
369                                                    ));
370                                                    yield Ok(StreamEvent::tool_use_delta(
371                                                        1,
372                                                        &serde_json::to_string(args).unwrap_or_default(),
373                                                    ));
374                                                    yield Ok(StreamEvent::content_block_stop(1));
375                                                }
376                                            }
377                                        }
378                                    }
379                                }
380                            }
381                        }
382                    }
383                }
384            }
385
386            yield Ok(StreamEvent::content_block_stop(0));
387
388            let stop_reason = if tool_used {
389                crate::types::streaming::StopReason::ToolUse
390            } else {
391                match finish_reason {
392                    "MAX_TOKENS" => crate::types::streaming::StopReason::MaxTokens,
393                    _ => crate::types::streaming::StopReason::EndTurn,
394                }
395            };
396
397            yield Ok(StreamEvent::message_stop(stop_reason));
398
399            yield Ok(StreamEvent::metadata(
400                crate::types::streaming::Usage::new(input_tokens as u32, output_tokens as u32),
401                crate::types::streaming::Metrics::default(),
402            ));
403        })
404    }
405}
406
#[cfg(test)]
mod tests {
    use super::*;

    /// Each builder method should populate the corresponding config field.
    #[test]
    fn test_gemini_config() {
        let cfg = GeminiConfig::new("gemini-2.5-flash")
            .with_param("temperature", serde_json::json!(0.7))
            .with_api_key("test-key");

        assert_eq!(cfg.model_id, "gemini-2.5-flash");
        assert_eq!(cfg.api_key.as_deref(), Some("test-key"));
        assert!(cfg.params.contains_key("temperature"));
    }
}
422