Skip to main content

stakpak_shared/models/integrations/
gemini.rs

//! Gemini provider configuration and model definitions
//!
//! This module contains configuration types and model enums for Google Gemini.
//! Request/response types for API communication are in `libs/ai/src/providers/gemini/`.
5
6use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
7use serde::{Deserialize, Serialize};
8
9/// Configuration for Gemini provider
10#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
11pub struct GeminiConfig {
12    pub api_endpoint: Option<String>,
13    pub api_key: Option<String>,
14}
15
16impl GeminiConfig {
17    /// Create config with API key
18    pub fn with_api_key(api_key: impl Into<String>) -> Self {
19        Self {
20            api_key: Some(api_key.into()),
21            api_endpoint: None,
22        }
23    }
24
25    /// Create config from ProviderAuth (only supports API key for Gemini)
26    pub fn from_provider_auth(auth: &crate::models::auth::ProviderAuth) -> Option<Self> {
27        match auth {
28            crate::models::auth::ProviderAuth::Api { key } => Some(Self::with_api_key(key)),
29            crate::models::auth::ProviderAuth::OAuth { .. } => None, // Gemini doesn't support OAuth in this impl
30        }
31    }
32
33    /// Merge with credentials from ProviderAuth, preserving existing endpoint
34    pub fn with_provider_auth(mut self, auth: &crate::models::auth::ProviderAuth) -> Option<Self> {
35        match auth {
36            crate::models::auth::ProviderAuth::Api { key } => {
37                self.api_key = Some(key.clone());
38                Some(self)
39            }
40            crate::models::auth::ProviderAuth::OAuth { .. } => None,
41        }
42    }
43}
44
45/// Gemini model identifiers
46#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
47pub enum GeminiModel {
48    #[default]
49    #[serde(rename = "gemini-3-pro-preview")]
50    Gemini3Pro,
51    #[serde(rename = "gemini-3-flash-preview")]
52    Gemini3Flash,
53    #[serde(rename = "gemini-2.5-pro")]
54    Gemini25Pro,
55    #[serde(rename = "gemini-2.5-flash")]
56    Gemini25Flash,
57    #[serde(rename = "gemini-2.5-flash-lite")]
58    Gemini25FlashLite,
59}
60
61impl std::fmt::Display for GeminiModel {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        match self {
64            GeminiModel::Gemini3Pro => write!(f, "gemini-3-pro-preview"),
65            GeminiModel::Gemini3Flash => write!(f, "gemini-3-flash-preview"),
66            GeminiModel::Gemini25Pro => write!(f, "gemini-2.5-pro"),
67            GeminiModel::Gemini25Flash => write!(f, "gemini-2.5-flash"),
68            GeminiModel::Gemini25FlashLite => write!(f, "gemini-2.5-flash-lite"),
69        }
70    }
71}
72
73impl GeminiModel {
74    pub fn from_string(s: &str) -> Result<Self, String> {
75        serde_json::from_value(serde_json::Value::String(s.to_string()))
76            .map_err(|_| "Failed to deserialize Gemini model".to_string())
77    }
78
79    /// Default smart model for Gemini
80    pub const DEFAULT_SMART_MODEL: GeminiModel = GeminiModel::Gemini3Pro;
81
82    /// Default eco model for Gemini
83    pub const DEFAULT_ECO_MODEL: GeminiModel = GeminiModel::Gemini3Flash;
84
85    /// Default recovery model for Gemini
86    pub const DEFAULT_RECOVERY_MODEL: GeminiModel = GeminiModel::Gemini3Flash;
87
88    /// Get default smart model as string
89    pub fn default_smart_model() -> String {
90        Self::DEFAULT_SMART_MODEL.to_string()
91    }
92
93    /// Get default eco model as string
94    pub fn default_eco_model() -> String {
95        Self::DEFAULT_ECO_MODEL.to_string()
96    }
97
98    /// Get default recovery model as string
99    pub fn default_recovery_model() -> String {
100        Self::DEFAULT_RECOVERY_MODEL.to_string()
101    }
102}
103
104impl ContextAware for GeminiModel {
105    fn context_info(&self) -> ModelContextInfo {
106        match self {
107            GeminiModel::Gemini3Pro => ModelContextInfo {
108                max_tokens: 1_000_000,
109                pricing_tiers: vec![
110                    ContextPricingTier {
111                        label: "<200k tokens".to_string(),
112                        input_cost_per_million: 2.0,
113                        output_cost_per_million: 12.0,
114                        upper_bound: Some(200_000),
115                    },
116                    ContextPricingTier {
117                        label: ">200k tokens".to_string(),
118                        input_cost_per_million: 4.0,
119                        output_cost_per_million: 18.0,
120                        upper_bound: None,
121                    },
122                ],
123                approach_warning_threshold: 0.8,
124            },
125            GeminiModel::Gemini25Pro => ModelContextInfo {
126                max_tokens: 1_000_000,
127                pricing_tiers: vec![
128                    ContextPricingTier {
129                        label: "<200k tokens".to_string(),
130                        input_cost_per_million: 1.25,
131                        output_cost_per_million: 10.0,
132                        upper_bound: Some(200_000),
133                    },
134                    ContextPricingTier {
135                        label: ">200k tokens".to_string(),
136                        input_cost_per_million: 2.50,
137                        output_cost_per_million: 15.0,
138                        upper_bound: None,
139                    },
140                ],
141                approach_warning_threshold: 0.8,
142            },
143            GeminiModel::Gemini25Flash => ModelContextInfo {
144                max_tokens: 1_000_000,
145                pricing_tiers: vec![ContextPricingTier {
146                    label: "Standard".to_string(),
147                    input_cost_per_million: 0.30,
148                    output_cost_per_million: 2.50,
149                    upper_bound: None,
150                }],
151                approach_warning_threshold: 0.8,
152            },
153            GeminiModel::Gemini3Flash => ModelContextInfo {
154                max_tokens: 1_000_000,
155                pricing_tiers: vec![ContextPricingTier {
156                    label: "Standard".to_string(),
157                    input_cost_per_million: 0.50,
158                    output_cost_per_million: 3.0,
159                    upper_bound: None,
160                }],
161                approach_warning_threshold: 0.8,
162            },
163            GeminiModel::Gemini25FlashLite => ModelContextInfo {
164                max_tokens: 1_000_000,
165                pricing_tiers: vec![ContextPricingTier {
166                    label: "Standard".to_string(),
167                    input_cost_per_million: 0.1,
168                    output_cost_per_million: 0.4,
169                    upper_bound: None,
170                }],
171                approach_warning_threshold: 0.8,
172            },
173        }
174    }
175
176    fn model_name(&self) -> String {
177        match self {
178            GeminiModel::Gemini3Pro => "Gemini 3 Pro".to_string(),
179            GeminiModel::Gemini3Flash => "Gemini 3 Flash".to_string(),
180            GeminiModel::Gemini25Pro => "Gemini 2.5 Pro".to_string(),
181            GeminiModel::Gemini25Flash => "Gemini 2.5 Flash".to_string(),
182            GeminiModel::Gemini25FlashLite => "Gemini 2.5 Flash Lite".to_string(),
183        }
184    }
185}