//! Gemini integration models (stakpak_shared/models/integrations/gemini.rs).

1use crate::models::llm::{
2    LLMChoice, LLMCompletionResponse, LLMMessage, LLMMessageContent, LLMMessageTypedContent,
3    LLMTokenUsage, LLMTool,
4};
5use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
6use serde::{Deserialize, Serialize};
7
/// Connection settings for the Gemini provider.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct GeminiConfig {
    /// Optional custom API endpoint; resolution of a default is handled by the caller.
    pub api_endpoint: Option<String>,
    /// Optional API key for authenticating requests.
    pub api_key: Option<String>,
}
13
/// Supported Gemini models; serde (de)serializes each variant to/from its
/// wire-format identifier (the `rename` string).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub enum GeminiModel {
    /// "gemini-3-pro-preview" — the default variant.
    #[default]
    #[serde(rename = "gemini-3-pro-preview")]
    Gemini3Pro,
    /// "gemini-3-flash-preview"
    #[serde(rename = "gemini-3-flash-preview")]
    Gemini3Flash,
    /// "gemini-2.5-pro"
    #[serde(rename = "gemini-2.5-pro")]
    Gemini25Pro,
    /// "gemini-2.5-flash"
    #[serde(rename = "gemini-2.5-flash")]
    Gemini25Flash,
    /// "gemini-2.5-flash-lite"
    #[serde(rename = "gemini-2.5-flash-lite")]
    Gemini25FlashLite,
}
28
29impl std::fmt::Display for GeminiModel {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        match self {
32            GeminiModel::Gemini3Pro => write!(f, "gemini-3-pro-preview"),
33            GeminiModel::Gemini3Flash => write!(f, "gemini-3-flash-preview"),
34            GeminiModel::Gemini25Pro => write!(f, "gemini-2.5-pro"),
35            GeminiModel::Gemini25Flash => write!(f, "gemini-2.5-flash"),
36            GeminiModel::Gemini25FlashLite => write!(f, "gemini-2.5-flash-lite"),
37        }
38    }
39}
40
41impl GeminiModel {
42    pub fn from_string(s: &str) -> Result<Self, String> {
43        serde_json::from_value(serde_json::Value::String(s.to_string()))
44            .map_err(|_| "Failed to deserialize Gemini model".to_string())
45    }
46
47    /// Default smart model for Gemini
48    pub const DEFAULT_SMART_MODEL: GeminiModel = GeminiModel::Gemini3Pro;
49
50    /// Default eco model for Gemini
51    pub const DEFAULT_ECO_MODEL: GeminiModel = GeminiModel::Gemini3Flash;
52
53    /// Default recovery model for Gemini
54    pub const DEFAULT_RECOVERY_MODEL: GeminiModel = GeminiModel::Gemini3Flash;
55
56    /// Get default smart model as string
57    pub fn default_smart_model() -> String {
58        Self::DEFAULT_SMART_MODEL.to_string()
59    }
60
61    /// Get default eco model as string
62    pub fn default_eco_model() -> String {
63        Self::DEFAULT_ECO_MODEL.to_string()
64    }
65
66    /// Get default recovery model as string
67    pub fn default_recovery_model() -> String {
68        Self::DEFAULT_RECOVERY_MODEL.to_string()
69    }
70}
71
72impl ContextAware for GeminiModel {
73    fn context_info(&self) -> ModelContextInfo {
74        match self {
75            GeminiModel::Gemini3Pro => ModelContextInfo {
76                max_tokens: 1_000_000,
77                pricing_tiers: vec![
78                    ContextPricingTier {
79                        label: "<200k tokens".to_string(),
80                        input_cost_per_million: 2.0,
81                        output_cost_per_million: 12.0,
82                        upper_bound: Some(200_000),
83                    },
84                    ContextPricingTier {
85                        label: ">200k tokens".to_string(),
86                        input_cost_per_million: 4.0,
87                        output_cost_per_million: 18.0,
88                        upper_bound: None,
89                    },
90                ],
91                approach_warning_threshold: 0.8,
92            },
93            GeminiModel::Gemini25Pro => ModelContextInfo {
94                max_tokens: 1_000_000,
95                pricing_tiers: vec![
96                    ContextPricingTier {
97                        label: "<200k tokens".to_string(),
98                        input_cost_per_million: 1.25,
99                        output_cost_per_million: 10.0,
100                        upper_bound: Some(200_000),
101                    },
102                    ContextPricingTier {
103                        label: ">200k tokens".to_string(),
104                        input_cost_per_million: 2.50,
105                        output_cost_per_million: 15.0,
106                        upper_bound: None,
107                    },
108                ],
109                approach_warning_threshold: 0.8,
110            },
111            GeminiModel::Gemini25Flash => ModelContextInfo {
112                max_tokens: 1_000_000,
113                pricing_tiers: vec![ContextPricingTier {
114                    label: "Standard".to_string(),
115                    input_cost_per_million: 0.30,
116                    output_cost_per_million: 2.50,
117                    upper_bound: None,
118                }],
119                approach_warning_threshold: 0.8,
120            },
121            GeminiModel::Gemini3Flash => ModelContextInfo {
122                max_tokens: 1_000_000,
123                pricing_tiers: vec![ContextPricingTier {
124                    label: "Standard".to_string(),
125                    input_cost_per_million: 0.50,
126                    output_cost_per_million: 3.0,
127                    upper_bound: None,
128                }],
129                approach_warning_threshold: 0.8,
130            },
131            GeminiModel::Gemini25FlashLite => ModelContextInfo {
132                max_tokens: 1_000_000,
133                pricing_tiers: vec![ContextPricingTier {
134                    label: "Standard".to_string(),
135                    input_cost_per_million: 0.1,
136                    output_cost_per_million: 0.4,
137                    upper_bound: None,
138                }],
139                approach_warning_threshold: 0.8,
140            },
141        }
142    }
143
144    fn model_name(&self) -> String {
145        match self {
146            GeminiModel::Gemini3Pro => "Gemini 3 Pro".to_string(),
147            GeminiModel::Gemini3Flash => "Gemini 3 Flash".to_string(),
148            GeminiModel::Gemini25Pro => "Gemini 2.5 Pro".to_string(),
149            GeminiModel::Gemini25Flash => "Gemini 2.5 Flash".to_string(),
150            GeminiModel::Gemini25FlashLite => "Gemini 2.5 Flash Lite".to_string(),
151        }
152    }
153}
154
/// Completion request in the shared LLM format, pairing a Gemini model with
/// messages and optional tools.
#[derive(Serialize, Deserialize, Debug)]
pub struct GeminiInput {
    /// Target model.
    pub model: GeminiModel,
    /// Conversation history in the provider-agnostic message format.
    pub messages: Vec<LLMMessage>,
    /// Upper bound on generated tokens.
    pub max_tokens: u32,
    /// Tool definitions; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<LLMTool>>,
}
163
/// Request body for the Gemini generate-content API; field names are
/// serialized in camelCase per the wire format.
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct GeminiRequest {
    /// Ordered conversation turns.
    pub contents: Vec<GeminiContent>,

    /// Tool (function) declarations available to the model; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<GeminiTool>>,

    /// System prompt, sent as `systemInstruction`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_instruction: Option<GeminiSystemInstruction>,

    /// Sampling/limit parameters, sent as `generationConfig`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_config: Option<GeminiGenerationConfig>,
}
178
/// Conversation roles understood by Gemini: only "user" and "model"
/// (no separate assistant/tool roles on the wire).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum GeminiRole {
    /// Serialized as "user".
    #[serde(rename = "user")]
    User,
    /// Serialized as "model".
    #[serde(rename = "model")]
    Model,
}
186
187impl std::fmt::Display for GeminiRole {
188    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        match self {
190            GeminiRole::User => write!(f, "user"),
191            GeminiRole::Model => write!(f, "model"),
192        }
193    }
194}
195
196impl GeminiRole {
197    pub fn from_string(s: &str) -> Result<Self, String> {
198        serde_json::from_value(serde_json::Value::String(s.to_string()))
199            .map_err(|_| "Failed to deserialize Gemini role".to_string())
200    }
201}
202
/// One conversation turn: a role plus its ordered content parts.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GeminiContent {
    pub role: GeminiRole,
    /// Content parts; `#[serde(default)]` tolerates responses that omit `parts`.
    #[serde(default)]
    pub parts: Vec<GeminiPart>,
}
209
/// A single content part. `#[serde(untagged)]` means the variant is chosen by
/// which JSON key is present (`text`, `functionCall`, `functionResponse`,
/// `inlineData`), tried in declaration order.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum GeminiPart {
    /// Plain text content.
    Text {
        text: String,
    },
    /// A tool invocation requested by the model.
    FunctionCall {
        #[serde(rename = "functionCall")]
        function_call: GeminiFunctionCall,
    },
    /// The result of a previously requested tool invocation.
    FunctionResponse {
        #[serde(rename = "functionResponse")]
        function_response: GeminiFunctionResponse,
    },
    /// Inline binary payload (e.g. an image).
    InlineData {
        #[serde(rename = "inlineData")]
        inline_data: GeminiInlineData,
    },
}
229
/// A function call emitted by the model.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GeminiFunctionCall {
    /// Call id; `#[serde(default)]` tolerates responses that omit it.
    #[serde(default)]
    pub id: Option<String>,
    /// Name of the declared function being invoked.
    pub name: String,
    /// Arguments as an arbitrary JSON value.
    pub args: serde_json::Value,
}
237
/// The result returned for an earlier function call.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GeminiFunctionResponse {
    /// Id correlating this result with the originating call.
    pub id: String,
    /// Name of the function that produced the result.
    pub name: String,
    /// Result payload as an arbitrary JSON value.
    pub response: serde_json::Value,
}
244
245#[derive(Serialize, Deserialize, Debug, Clone)]
246pub struct GeminiInlineData {
247    pub mime_type: String,
248    pub data: String,
249}
250
/// System prompt payload (`systemInstruction` in the request); carries only
/// parts, no role.
#[derive(Serialize, Deserialize, Debug)]
pub struct GeminiSystemInstruction {
    pub parts: Vec<GeminiPart>,
}
255
256#[derive(Serialize, Deserialize, Debug)]
257pub struct GeminiTool {
258    pub function_declarations: Vec<GeminiFunctionDeclaration>,
259}
260
261#[derive(Serialize, Deserialize, Debug)]
262pub struct GeminiFunctionDeclaration {
263    pub name: String,
264    pub description: String,
265    pub parameters_json_schema: Option<serde_json::Value>,
266}
267
268#[derive(Serialize, Deserialize, Debug)]
269pub struct GeminiGenerationConfig {
270    pub max_output_tokens: Option<u32>,
271    pub temperature: Option<f32>,
272    pub candidate_count: Option<u32>,
273}
274
275// Gemini API Response Structs
276
/// Top-level generate-content response; every field is optional so partial or
/// streaming payloads still deserialize.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GeminiResponse {
    /// Generated candidates, if any.
    pub candidates: Option<Vec<GeminiCandidate>>,
    /// Token accounting for the request/response.
    pub usage_metadata: Option<GeminiUsageMetadata>,
    /// Concrete model version that served the request.
    pub model_version: Option<String>,
    /// Server-assigned response id.
    pub response_id: Option<String>,
}
285
/// One generated candidate within a response.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GeminiCandidate {
    /// The generated turn; may be absent (e.g. when generation was blocked).
    pub content: Option<GeminiContent>,
    /// Raw finish reason string as reported by the API.
    pub finish_reason: Option<String>,
    /// Candidate index within the response.
    pub index: Option<u32>,
}
293
/// Token accounting reported by the API; all counts optional.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GeminiUsageMetadata {
    /// Tokens in the prompt.
    pub prompt_token_count: Option<u32>,
    /// Prompt tokens served from cache.
    pub cached_content_token_count: Option<u32>,
    /// Tokens across all generated candidates.
    pub candidates_token_count: Option<u32>,
    /// Prompt tokens attributable to tool definitions/use.
    pub tool_use_prompt_token_count: Option<u32>,
    /// Tokens spent on internal "thinking".
    pub thoughts_token_count: Option<u32>,
    /// Grand total across all of the above.
    pub total_token_count: Option<u32>,
}
304
305impl From<LLMMessage> for GeminiContent {
306    fn from(message: LLMMessage) -> Self {
307        let role = match message.role.as_str() {
308            "assistant" | "model" => GeminiRole::Model,
309            "user" | "tool" => GeminiRole::User,
310            _ => GeminiRole::User,
311        };
312
313        let parts = match message.content {
314            LLMMessageContent::String(text) => vec![GeminiPart::Text { text }],
315            LLMMessageContent::List(items) => items
316                .into_iter()
317                .map(|item| match item {
318                    LLMMessageTypedContent::Text { text } => GeminiPart::Text { text },
319
320                    LLMMessageTypedContent::ToolCall { id, name, args } => {
321                        GeminiPart::FunctionCall {
322                            function_call: GeminiFunctionCall {
323                                id: Some(id),
324                                name,
325                                args,
326                            },
327                        }
328                    }
329
330                    LLMMessageTypedContent::ToolResult { content, .. } => {
331                        GeminiPart::Text { text: content }
332                    }
333
334                    LLMMessageTypedContent::Image { source } => GeminiPart::InlineData {
335                        inline_data: GeminiInlineData {
336                            mime_type: source.media_type,
337                            data: source.data,
338                        },
339                    },
340                })
341                .collect(),
342        };
343
344        GeminiContent { role, parts }
345    }
346}
347
348// Conversion from GeminiContent to LLMMessage
349impl From<GeminiContent> for LLMMessage {
350    fn from(content: GeminiContent) -> Self {
351        let role = content.role.to_string();
352        let mut message_content = Vec::new();
353
354        for part in content.parts {
355            match part {
356                GeminiPart::Text { text } => {
357                    message_content.push(LLMMessageTypedContent::Text { text });
358                }
359                GeminiPart::FunctionCall { function_call } => {
360                    message_content.push(LLMMessageTypedContent::ToolCall {
361                        id: function_call.id.unwrap_or_else(|| "".to_string()),
362                        name: function_call.name,
363                        args: function_call.args,
364                    });
365                }
366                GeminiPart::FunctionResponse { function_response } => {
367                    message_content.push(LLMMessageTypedContent::ToolResult {
368                        tool_use_id: function_response.id,
369                        content: function_response.response.to_string(),
370                    });
371                }
372                //TODO: Add Image support
373                _ => {}
374            }
375        }
376
377        let content = if message_content.is_empty() {
378            LLMMessageContent::String(String::new())
379        } else if message_content.len() == 1 {
380            match &message_content[0] {
381                LLMMessageTypedContent::Text { text } => LLMMessageContent::String(text.clone()),
382                _ => LLMMessageContent::List(message_content),
383            }
384        } else {
385            LLMMessageContent::List(message_content)
386        };
387
388        LLMMessage { role, content }
389    }
390}
391
392impl From<LLMTool> for GeminiFunctionDeclaration {
393    fn from(tool: LLMTool) -> Self {
394        GeminiFunctionDeclaration {
395            name: tool.name,
396            description: tool.description,
397            parameters_json_schema: Some(tool.input_schema),
398        }
399    }
400}
401
402impl From<Vec<LLMTool>> for GeminiTool {
403    fn from(tools: Vec<LLMTool>) -> Self {
404        GeminiTool {
405            function_declarations: tools.into_iter().map(|t| t.into()).collect(),
406        }
407    }
408}
409
410impl From<GeminiResponse> for LLMCompletionResponse {
411    fn from(response: GeminiResponse) -> Self {
412        let usage = response.usage_metadata.map(|u| LLMTokenUsage {
413            prompt_tokens: u.prompt_token_count.unwrap_or(0),
414            completion_tokens: u.candidates_token_count.unwrap_or(0),
415            total_tokens: u.total_token_count.unwrap_or(0),
416            prompt_tokens_details: None,
417        });
418
419        let choices = response
420            .candidates
421            .unwrap_or_default()
422            .into_iter()
423            .enumerate()
424            .map(|(index, candidate)| {
425                let message = candidate
426                    .content
427                    .map(|c| c.into())
428                    .unwrap_or_else(|| LLMMessage {
429                        role: "model".to_string(),
430                        content: LLMMessageContent::String(String::new()),
431                    });
432
433                let has_tool_calls = match &message.content {
434                    LLMMessageContent::List(items) => items
435                        .iter()
436                        .any(|item| matches!(item, LLMMessageTypedContent::ToolCall { .. })),
437                    _ => false,
438                };
439
440                let finish_reason = if has_tool_calls {
441                    Some("tool_calls".to_string())
442                } else {
443                    candidate.finish_reason.map(|s| s.to_lowercase())
444                };
445
446                LLMChoice {
447                    finish_reason,
448                    index: index as u32,
449                    message,
450                }
451            })
452            .collect();
453
454        LLMCompletionResponse {
455            // Use model_version from the response, with fallback
456            model: response
457                .model_version
458                .unwrap_or_else(|| "gemini".to_string()),
459            object: "chat.completion".to_string(),
460            choices,
461            created: chrono::Utc::now().timestamp_millis() as u64,
462            usage,
463            id: response
464                .response_id
465                .unwrap_or_else(|| "unknown".to_string()),
466        }
467    }
468}