1use std::collections::HashMap;
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9
10use crate::models::{Model, ModelConfig, StreamEventStream};
11use crate::types::content::{Message, SystemContentBlock};
12use crate::types::errors::StrandsError;
13use crate::types::streaming::StreamEvent;
14use crate::types::tools::{ToolChoice, ToolSpec};
15
16#[derive(Debug, Clone, Default)]
18pub struct GeminiConfig {
19 pub model_id: String,
21 pub params: HashMap<String, serde_json::Value>,
23 pub api_key: Option<String>,
25 pub base_url: Option<String>,
27}
28
29impl GeminiConfig {
30 pub fn new(model_id: impl Into<String>) -> Self {
31 Self {
32 model_id: model_id.into(),
33 ..Default::default()
34 }
35 }
36
37 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
38 self.api_key = Some(api_key.into());
39 self
40 }
41
42 pub fn with_param(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
43 self.params.insert(key.into(), value);
44 self
45 }
46}
47
48#[derive(Debug, Serialize)]
50#[serde(rename_all = "camelCase")]
51struct GeminiRequest {
52 contents: Vec<GeminiContent>,
53 #[serde(skip_serializing_if = "Option::is_none")]
54 system_instruction: Option<GeminiContent>,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 tools: Option<Vec<GeminiTool>>,
57 #[serde(skip_serializing_if = "Option::is_none")]
58 generation_config: Option<serde_json::Value>,
59}
60
61#[derive(Debug, Serialize, Deserialize)]
62struct GeminiContent {
63 role: String,
64 parts: Vec<GeminiPart>,
65}
66
67#[derive(Debug, Serialize, Deserialize)]
68#[serde(untagged)]
69enum GeminiPart {
70 Text { text: String },
71 FunctionCall { function_call: GeminiFunctionCall },
72 FunctionResponse { function_response: GeminiFunctionResponse },
73}
74
75#[derive(Debug, Serialize, Deserialize)]
76struct GeminiFunctionCall {
77 name: String,
78 args: serde_json::Value,
79}
80
81#[derive(Debug, Serialize, Deserialize)]
82struct GeminiFunctionResponse {
83 name: String,
84 response: serde_json::Value,
85}
86
87#[derive(Debug, Serialize)]
88struct GeminiTool {
89 function_declarations: Vec<GeminiFunctionDeclaration>,
90}
91
92#[derive(Debug, Serialize)]
93struct GeminiFunctionDeclaration {
94 name: String,
95 description: String,
96 parameters: serde_json::Value,
97}
98
99pub struct GeminiModel {
101 config: ModelConfig,
102 gemini_config: GeminiConfig,
103 client: reqwest::Client,
104}
105
106impl GeminiModel {
107 const DEFAULT_BASE_URL: &'static str = "https://generativelanguage.googleapis.com/v1beta";
108
109 pub fn new(config: GeminiConfig) -> Self {
110 let model_config = ModelConfig::new(&config.model_id);
111
112 Self {
113 config: model_config,
114 gemini_config: config,
115 client: reqwest::Client::new(),
116 }
117 }
118
119 fn base_url(&self) -> &str {
120 self.gemini_config
121 .base_url
122 .as_deref()
123 .unwrap_or(Self::DEFAULT_BASE_URL)
124 }
125
126 fn api_key(&self) -> Result<&str, StrandsError> {
127 self.gemini_config
128 .api_key
129 .as_deref()
130 .or_else(|| std::env::var("GOOGLE_API_KEY").ok().as_deref().map(|_| ""))
131 .ok_or_else(|| StrandsError::ConfigurationError {
132 message: "Gemini API key not configured. Set GOOGLE_API_KEY or provide api_key".into(),
133 })
134 }
135
136 fn convert_messages(&self, messages: &[Message]) -> Vec<GeminiContent> {
137 messages
138 .iter()
139 .map(|msg| {
140 let role = match msg.role {
141 crate::types::content::Role::User => "user",
142 crate::types::content::Role::Assistant => "model",
143 };
144
145 let parts: Vec<GeminiPart> = msg
146 .content
147 .iter()
148 .filter_map(|block| {
149 if let Some(text) = &block.text {
150 Some(GeminiPart::Text { text: text.clone() })
151 } else if let Some(tool_use) = &block.tool_use {
152 Some(GeminiPart::FunctionCall {
153 function_call: GeminiFunctionCall {
154 name: tool_use.name.clone(),
155 args: tool_use.input.clone(),
156 },
157 })
158 } else if let Some(tool_result) = &block.tool_result {
159 Some(GeminiPart::FunctionResponse {
160 function_response: GeminiFunctionResponse {
161 name: tool_result.tool_use_id.clone(),
162 response: serde_json::json!({
163 "content": tool_result.content
164 }),
165 },
166 })
167 } else {
168 None
169 }
170 })
171 .collect();
172
173 GeminiContent {
174 role: role.to_string(),
175 parts,
176 }
177 })
178 .collect()
179 }
180
181 fn convert_tools(&self, tool_specs: &[ToolSpec]) -> Vec<GeminiTool> {
182 let declarations: Vec<GeminiFunctionDeclaration> = tool_specs
183 .iter()
184 .map(|spec| GeminiFunctionDeclaration {
185 name: spec.name.clone(),
186 description: spec.description.clone(),
187 parameters: serde_json::to_value(&spec.input_schema).unwrap_or_default(),
188 })
189 .collect();
190
191 vec![GeminiTool {
192 function_declarations: declarations,
193 }]
194 }
195}
196
197#[async_trait]
198impl Model for GeminiModel {
199 fn config(&self) -> &ModelConfig {
200 &self.config
201 }
202
203 fn update_config(&mut self, config: ModelConfig) {
204 self.config = config;
205 }
206
207 fn stream<'a>(
208 &'a self,
209 messages: &'a [Message],
210 tool_specs: Option<&'a [ToolSpec]>,
211 system_prompt: Option<&'a str>,
212 _tool_choice: Option<ToolChoice>,
213 _system_prompt_content: Option<&'a [SystemContentBlock]>,
214 ) -> StreamEventStream<'a> {
215 let messages = messages.to_vec();
216 let tool_specs = tool_specs.map(|t| t.to_vec());
217 let system_prompt = system_prompt.map(|s| s.to_string());
218
219 Box::pin(async_stream::stream! {
220 let api_key = match self.api_key() {
221 Ok(key) => key.to_string(),
222 Err(e) => {
223 yield Err(e);
224 return;
225 }
226 };
227
228 let api_key = if api_key.is_empty() {
229 match std::env::var("GOOGLE_API_KEY") {
230 Ok(key) => key,
231 Err(_) => {
232 yield Err(StrandsError::ConfigurationError {
233 message: "GOOGLE_API_KEY not set".into(),
234 });
235 return;
236 }
237 }
238 } else {
239 api_key
240 };
241
242 let contents = self.convert_messages(&messages);
243
244 let system_instruction = system_prompt.map(|prompt| GeminiContent {
245 role: "user".to_string(),
246 parts: vec![GeminiPart::Text { text: prompt }],
247 });
248
249 let tools = tool_specs.as_ref().map(|specs| self.convert_tools(specs));
250
251 let request = GeminiRequest {
252 contents,
253 system_instruction,
254 tools,
255 generation_config: if self.gemini_config.params.is_empty() {
256 None
257 } else {
258 Some(serde_json::to_value(&self.gemini_config.params).unwrap_or_default())
259 },
260 };
261
262 let url = format!(
263 "{}/models/{}:streamGenerateContent?key={}&alt=sse",
264 self.base_url(),
265 self.config.model_id,
266 api_key
267 );
268
269 let response = match self.client
270 .post(&url)
271 .json(&request)
272 .send()
273 .await
274 {
275 Ok(resp) => resp,
276 Err(e) => {
277 yield Err(StrandsError::NetworkError(e.to_string()));
278 return;
279 }
280 };
281
282 if !response.status().is_success() {
283 let status = response.status();
284 let body = response.text().await.unwrap_or_default();
285
286 if status.as_u16() == 429 {
287 yield Err(StrandsError::ModelThrottled {
288 message: "Gemini rate limit exceeded".into(),
289 });
290 } else {
291 yield Err(StrandsError::ModelError {
292 message: format!("Gemini API error {}: {}", status, body),
293 source: None,
294 });
295 }
296 return;
297 }
298
299 yield Ok(StreamEvent::message_start(crate::types::content::Role::Assistant));
300 yield Ok(StreamEvent::content_block_start(0, None));
301
302 let body = match response.text().await {
303 Ok(b) => b,
304 Err(e) => {
305 yield Err(StrandsError::NetworkError(e.to_string()));
306 return;
307 }
308 };
309
310 let mut tool_used = false;
311 let mut finish_reason = "STOP";
312 let mut input_tokens = 0u64;
313 let mut output_tokens = 0u64;
314
315 for line in body.lines() {
316 let line = line.trim();
317
318 if line.is_empty() || line.starts_with(':') {
319 continue;
320 }
321
322 if let Some(data) = line.strip_prefix("data: ") {
323 if data.trim() == "[DONE]" {
324 continue;
325 }
326
327 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(data) {
328 if let Some(usage) = parsed.get("usageMetadata") {
329 if let Some(prompt_tokens) = usage.get("promptTokenCount").and_then(|v| v.as_u64()) {
330 input_tokens = prompt_tokens;
331 }
332 if let Some(candidates_tokens) = usage.get("candidatesTokenCount").and_then(|v| v.as_u64()) {
333 output_tokens = candidates_tokens;
334 }
335 }
336
337 if let Some(candidates) = parsed.get("candidates").and_then(|c| c.as_array()) {
338 for candidate in candidates {
339 if let Some(reason) = candidate.get("finishReason").and_then(|r| r.as_str()) {
340 finish_reason = match reason {
341 "MAX_TOKENS" => "MAX_TOKENS",
342 "SAFETY" => "SAFETY",
343 "STOP" | _ => "STOP",
344 };
345 }
346
347 if let Some(content) = candidate.get("content") {
348 if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) {
349 for part in parts {
350 if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
351 let is_thought = part.get("thought").and_then(|t| t.as_bool()).unwrap_or(false);
352 if is_thought {
353 yield Ok(StreamEvent::reasoning_delta(0, text));
354 } else {
355 yield Ok(StreamEvent::text_delta(0, text));
356 }
357 }
358
359 if let Some(function_call) = part.get("functionCall") {
360 if let (Some(name), Some(args)) = (
361 function_call.get("name").and_then(|n| n.as_str()),
362 function_call.get("args"),
363 ) {
364 tool_used = true;
365 yield Ok(StreamEvent::tool_use_start(
366 1,
367 name,
368 name,
369 ));
370 yield Ok(StreamEvent::tool_use_delta(
371 1,
372 &serde_json::to_string(args).unwrap_or_default(),
373 ));
374 yield Ok(StreamEvent::content_block_stop(1));
375 }
376 }
377 }
378 }
379 }
380 }
381 }
382 }
383 }
384 }
385
386 yield Ok(StreamEvent::content_block_stop(0));
387
388 let stop_reason = if tool_used {
389 crate::types::streaming::StopReason::ToolUse
390 } else {
391 match finish_reason {
392 "MAX_TOKENS" => crate::types::streaming::StopReason::MaxTokens,
393 _ => crate::types::streaming::StopReason::EndTurn,
394 }
395 };
396
397 yield Ok(StreamEvent::message_stop(stop_reason));
398
399 yield Ok(StreamEvent::metadata(
400 crate::types::streaming::Usage::new(input_tokens as u32, output_tokens as u32),
401 crate::types::streaming::Metrics::default(),
402 ));
403 })
404 }
405}
406
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder methods record the model id, the API key and any extra
    /// parameter on the resulting configuration.
    #[test]
    fn test_gemini_config() {
        let built = GeminiConfig::new("gemini-2.5-flash")
            .with_api_key("test-key")
            .with_param("temperature", serde_json::json!(0.7));

        assert_eq!(built.model_id, "gemini-2.5-flash");
        assert_eq!(built.api_key.as_deref(), Some("test-key"));
        assert!(built.params.get("temperature").is_some());
    }
}
422