1use crate::models::llm::{
2 LLMChoice, LLMCompletionResponse, LLMMessage, LLMMessageContent, LLMMessageTypedContent,
3 LLMTokenUsage, LLMTool,
4};
5use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
6use serde::{Deserialize, Serialize};
7
/// Connection settings for the Gemini provider.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct GeminiConfig {
    // Override for the API base URL; `None` presumably falls back to a
    // provider default — confirm against the client that consumes this.
    pub api_endpoint: Option<String>,
    // API key; `None` if the key is supplied through other means.
    pub api_key: Option<String>,
}
13
/// Supported Gemini model identifiers.
///
/// The `serde(rename)` strings are the exact wire ids; `Display` (below)
/// and `from_string` round-trip through these same strings.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub enum GeminiModel {
    #[default]
    #[serde(rename = "gemini-3-pro-preview")]
    Gemini3Pro,
    #[serde(rename = "gemini-3-flash-preview")]
    Gemini3Flash,
    #[serde(rename = "gemini-2.5-pro")]
    Gemini25Pro,
    #[serde(rename = "gemini-2.5-flash")]
    Gemini25Flash,
    #[serde(rename = "gemini-2.5-flash-lite")]
    Gemini25FlashLite,
}
28
29impl std::fmt::Display for GeminiModel {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 match self {
32 GeminiModel::Gemini3Pro => write!(f, "gemini-3-pro-preview"),
33 GeminiModel::Gemini3Flash => write!(f, "gemini-3-flash-preview"),
34 GeminiModel::Gemini25Pro => write!(f, "gemini-2.5-pro"),
35 GeminiModel::Gemini25Flash => write!(f, "gemini-2.5-flash"),
36 GeminiModel::Gemini25FlashLite => write!(f, "gemini-2.5-flash-lite"),
37 }
38 }
39}
40
41impl GeminiModel {
42 pub fn from_string(s: &str) -> Result<Self, String> {
43 serde_json::from_value(serde_json::Value::String(s.to_string()))
44 .map_err(|_| "Failed to deserialize Gemini model".to_string())
45 }
46
47 pub const DEFAULT_SMART_MODEL: GeminiModel = GeminiModel::Gemini3Pro;
49
50 pub const DEFAULT_ECO_MODEL: GeminiModel = GeminiModel::Gemini3Flash;
52
53 pub const DEFAULT_RECOVERY_MODEL: GeminiModel = GeminiModel::Gemini3Flash;
55
56 pub fn default_smart_model() -> String {
58 Self::DEFAULT_SMART_MODEL.to_string()
59 }
60
61 pub fn default_eco_model() -> String {
63 Self::DEFAULT_ECO_MODEL.to_string()
64 }
65
66 pub fn default_recovery_model() -> String {
68 Self::DEFAULT_RECOVERY_MODEL.to_string()
69 }
70}
71
impl ContextAware for GeminiModel {
    /// Context-window size and pricing tiers for each model.
    ///
    /// Costs are per million tokens (per the `ContextPricingTier` field
    /// names; currency is not stated here — presumably USD). Pro models use
    /// two tiers split at 200k prompt tokens; Flash models use a single flat
    /// tier. NOTE(review): these figures should mirror Google's published
    /// list prices — re-verify whenever models or pricing change.
    fn context_info(&self) -> ModelContextInfo {
        match self {
            GeminiModel::Gemini3Pro => ModelContextInfo {
                max_tokens: 1_000_000,
                pricing_tiers: vec![
                    ContextPricingTier {
                        label: "<200k tokens".to_string(),
                        input_cost_per_million: 2.0,
                        output_cost_per_million: 12.0,
                        upper_bound: Some(200_000),
                    },
                    ContextPricingTier {
                        label: ">200k tokens".to_string(),
                        input_cost_per_million: 4.0,
                        output_cost_per_million: 18.0,
                        // Top tier: no upper limit.
                        upper_bound: None,
                    },
                ],
                // Warn when usage reaches 80% of the context window.
                approach_warning_threshold: 0.8,
            },
            GeminiModel::Gemini25Pro => ModelContextInfo {
                max_tokens: 1_000_000,
                pricing_tiers: vec![
                    ContextPricingTier {
                        label: "<200k tokens".to_string(),
                        input_cost_per_million: 1.25,
                        output_cost_per_million: 10.0,
                        upper_bound: Some(200_000),
                    },
                    ContextPricingTier {
                        label: ">200k tokens".to_string(),
                        input_cost_per_million: 2.50,
                        output_cost_per_million: 15.0,
                        upper_bound: None,
                    },
                ],
                approach_warning_threshold: 0.8,
            },
            GeminiModel::Gemini25Flash => ModelContextInfo {
                max_tokens: 1_000_000,
                // Flash models: flat pricing, single tier.
                pricing_tiers: vec![ContextPricingTier {
                    label: "Standard".to_string(),
                    input_cost_per_million: 0.30,
                    output_cost_per_million: 2.50,
                    upper_bound: None,
                }],
                approach_warning_threshold: 0.8,
            },
            GeminiModel::Gemini3Flash => ModelContextInfo {
                max_tokens: 1_000_000,
                pricing_tiers: vec![ContextPricingTier {
                    label: "Standard".to_string(),
                    input_cost_per_million: 0.50,
                    output_cost_per_million: 3.0,
                    upper_bound: None,
                }],
                approach_warning_threshold: 0.8,
            },
            GeminiModel::Gemini25FlashLite => ModelContextInfo {
                max_tokens: 1_000_000,
                pricing_tiers: vec![ContextPricingTier {
                    label: "Standard".to_string(),
                    input_cost_per_million: 0.1,
                    output_cost_per_million: 0.4,
                    upper_bound: None,
                }],
                approach_warning_threshold: 0.8,
            },
        }
    }

    /// Human-readable model name for display purposes.
    fn model_name(&self) -> String {
        match self {
            GeminiModel::Gemini3Pro => "Gemini 3 Pro".to_string(),
            GeminiModel::Gemini3Flash => "Gemini 3 Flash".to_string(),
            GeminiModel::Gemini25Pro => "Gemini 2.5 Pro".to_string(),
            GeminiModel::Gemini25Flash => "Gemini 2.5 Flash".to_string(),
            GeminiModel::Gemini25FlashLite => "Gemini 2.5 Flash Lite".to_string(),
        }
    }
}
154
/// Input bundle handed to the Gemini backend in the internal message format.
#[derive(Serialize, Deserialize, Debug)]
pub struct GeminiInput {
    // Which Gemini model to call.
    pub model: GeminiModel,
    // Full conversation history, oldest first.
    pub messages: Vec<LLMMessage>,
    // Maximum number of tokens to generate.
    pub max_tokens: u32,
    // Optional tool definitions; the key is omitted from JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<LLMTool>>,
}
163
164#[derive(Serialize, Deserialize, Debug)]
165#[serde(rename_all = "camelCase")]
166pub struct GeminiRequest {
167 pub contents: Vec<GeminiContent>,
168
169 #[serde(skip_serializing_if = "Option::is_none")]
170 pub tools: Option<Vec<GeminiTool>>,
171
172 #[serde(skip_serializing_if = "Option::is_none")]
173 pub system_instruction: Option<GeminiSystemInstruction>, #[serde(skip_serializing_if = "Option::is_none")]
176 pub generation_config: Option<GeminiGenerationConfig>, }
178
/// The two conversation roles Gemini recognizes.
///
/// Note the assistant role is spelled "model" on the wire, not "assistant".
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum GeminiRole {
    #[serde(rename = "user")]
    User,
    #[serde(rename = "model")]
    Model,
}
186
187impl std::fmt::Display for GeminiRole {
188 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189 match self {
190 GeminiRole::User => write!(f, "user"),
191 GeminiRole::Model => write!(f, "model"),
192 }
193 }
194}
195
196impl GeminiRole {
197 pub fn from_string(s: &str) -> Result<Self, String> {
198 serde_json::from_value(serde_json::Value::String(s.to_string()))
199 .map_err(|_| "Failed to deserialize Gemini role".to_string())
200 }
201}
202
/// One conversation turn: a role plus its content parts.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GeminiContent {
    pub role: GeminiRole,
    // Defaults to an empty list when the key is missing from the JSON.
    #[serde(default)]
    pub parts: Vec<GeminiPart>,
}
209
/// One piece of a conversation turn.
///
/// `untagged`: serde selects the variant by which unique key is present in
/// the JSON object (`text`, `functionCall`, `functionResponse`, or
/// `inlineData`) rather than by an explicit tag.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum GeminiPart {
    /// Plain text content.
    Text {
        text: String,
    },
    /// A function (tool) invocation requested by the model.
    FunctionCall {
        #[serde(rename = "functionCall")]
        function_call: GeminiFunctionCall,
    },
    /// The result of a previously requested function call.
    FunctionResponse {
        #[serde(rename = "functionResponse")]
        function_response: GeminiFunctionResponse,
    },
    /// Embedded binary data (e.g. an image), base64-encoded.
    InlineData {
        #[serde(rename = "inlineData")]
        inline_data: GeminiInlineData,
    },
}
229
/// A function call emitted by the model.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GeminiFunctionCall {
    // Call id; `default` because the API may omit it.
    #[serde(default)]
    pub id: Option<String>,
    // Name of the declared function being invoked.
    pub name: String,
    // Arguments as arbitrary JSON.
    pub args: serde_json::Value,
}
237
/// The supplied result for an earlier function call.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GeminiFunctionResponse {
    // Id of the function call this responds to.
    pub id: String,
    // Name of the function that was invoked.
    pub name: String,
    // Result payload as arbitrary JSON.
    pub response: serde_json::Value,
}
244
245#[derive(Serialize, Deserialize, Debug, Clone)]
246pub struct GeminiInlineData {
247 pub mime_type: String,
248 pub data: String,
249}
250
251#[derive(Serialize, Deserialize, Debug)]
252pub struct GeminiSystemInstruction {
253 pub parts: Vec<GeminiPart>, }
255
/// Group of function declarations attached to a request.
///
/// NOTE(review): this serializes as snake_case `function_declarations`;
/// Gemini's proto-JSON parsing accepts snake_case field names on requests,
/// but the documented key is camelCase `functionDeclarations` — confirm.
#[derive(Serialize, Deserialize, Debug)]
pub struct GeminiTool {
    pub function_declarations: Vec<GeminiFunctionDeclaration>,
}
260
/// Declaration of a single callable function.
///
/// NOTE(review): fields serialize as snake_case; Gemini's proto-JSON parser
/// accepts snake_case on requests, though documented keys are camelCase —
/// confirm against the API before relying on round-trips.
#[derive(Serialize, Deserialize, Debug)]
pub struct GeminiFunctionDeclaration {
    pub name: String,
    pub description: String,
    // JSON Schema describing the function's parameters, if any.
    pub parameters_json_schema: Option<serde_json::Value>,
}
267
268#[derive(Serialize, Deserialize, Debug)]
269pub struct GeminiGenerationConfig {
270 pub max_output_tokens: Option<u32>,
271 pub temperature: Option<f32>,
272 pub candidate_count: Option<u32>,
273}
274
/// Top-level response from the Gemini API.
///
/// Every field is optional so partial or unusual payloads still deserialize
/// instead of failing the whole response.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GeminiResponse {
    pub candidates: Option<Vec<GeminiCandidate>>,
    pub usage_metadata: Option<GeminiUsageMetadata>,
    // Concrete model version that served the request.
    pub model_version: Option<String>,
    pub response_id: Option<String>,
}
285
/// One generated completion candidate.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GeminiCandidate {
    // The generated turn; may be absent (e.g. on safety blocks — TODO confirm).
    pub content: Option<GeminiContent>,
    // Raw finish reason string as reported by the API.
    pub finish_reason: Option<String>,
    pub index: Option<u32>,
}
293
/// Token accounting attached to a response; all counts are optional.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GeminiUsageMetadata {
    // Tokens in the prompt.
    pub prompt_token_count: Option<u32>,
    // Portion of the prompt served from cache.
    pub cached_content_token_count: Option<u32>,
    // Tokens across generated candidates.
    pub candidates_token_count: Option<u32>,
    // Prompt tokens attributable to tool declarations.
    pub tool_use_prompt_token_count: Option<u32>,
    // Internal reasoning ("thinking") tokens.
    pub thoughts_token_count: Option<u32>,
    // Grand total reported by the API.
    pub total_token_count: Option<u32>,
}
304
305impl From<LLMMessage> for GeminiContent {
306 fn from(message: LLMMessage) -> Self {
307 let role = match message.role.as_str() {
308 "assistant" | "model" => GeminiRole::Model,
309 "user" | "tool" => GeminiRole::User,
310 _ => GeminiRole::User,
311 };
312
313 let parts = match message.content {
314 LLMMessageContent::String(text) => vec![GeminiPart::Text { text }],
315 LLMMessageContent::List(items) => items
316 .into_iter()
317 .map(|item| match item {
318 LLMMessageTypedContent::Text { text } => GeminiPart::Text { text },
319
320 LLMMessageTypedContent::ToolCall { id, name, args } => {
321 GeminiPart::FunctionCall {
322 function_call: GeminiFunctionCall {
323 id: Some(id),
324 name,
325 args,
326 },
327 }
328 }
329
330 LLMMessageTypedContent::ToolResult { content, .. } => {
331 GeminiPart::Text { text: content }
332 }
333
334 LLMMessageTypedContent::Image { source } => GeminiPart::InlineData {
335 inline_data: GeminiInlineData {
336 mime_type: source.media_type,
337 data: source.data,
338 },
339 },
340 })
341 .collect(),
342 };
343
344 GeminiContent { role, parts }
345 }
346}
347
348impl From<GeminiContent> for LLMMessage {
350 fn from(content: GeminiContent) -> Self {
351 let role = content.role.to_string();
352 let mut message_content = Vec::new();
353
354 for part in content.parts {
355 match part {
356 GeminiPart::Text { text } => {
357 message_content.push(LLMMessageTypedContent::Text { text });
358 }
359 GeminiPart::FunctionCall { function_call } => {
360 message_content.push(LLMMessageTypedContent::ToolCall {
361 id: function_call.id.unwrap_or_else(|| "".to_string()),
362 name: function_call.name,
363 args: function_call.args,
364 });
365 }
366 GeminiPart::FunctionResponse { function_response } => {
367 message_content.push(LLMMessageTypedContent::ToolResult {
368 tool_use_id: function_response.id,
369 content: function_response.response.to_string(),
370 });
371 }
372 _ => {}
374 }
375 }
376
377 let content = if message_content.is_empty() {
378 LLMMessageContent::String(String::new())
379 } else if message_content.len() == 1 {
380 match &message_content[0] {
381 LLMMessageTypedContent::Text { text } => LLMMessageContent::String(text.clone()),
382 _ => LLMMessageContent::List(message_content),
383 }
384 } else {
385 LLMMessageContent::List(message_content)
386 };
387
388 LLMMessage { role, content }
389 }
390}
391
392impl From<LLMTool> for GeminiFunctionDeclaration {
393 fn from(tool: LLMTool) -> Self {
394 GeminiFunctionDeclaration {
395 name: tool.name,
396 description: tool.description,
397 parameters_json_schema: Some(tool.input_schema),
398 }
399 }
400}
401
402impl From<Vec<LLMTool>> for GeminiTool {
403 fn from(tools: Vec<LLMTool>) -> Self {
404 GeminiTool {
405 function_declarations: tools.into_iter().map(|t| t.into()).collect(),
406 }
407 }
408}
409
impl From<GeminiResponse> for LLMCompletionResponse {
    /// Adapts a raw Gemini response into the internal completion shape
    /// (the `object: "chat.completion"` style used elsewhere).
    fn from(response: GeminiResponse) -> Self {
        // Token accounting: absent counts collapse to 0.
        // NOTE(review): completion_tokens uses candidates_token_count only;
        // thoughts_token_count is not folded in — confirm that is intended.
        let usage = response.usage_metadata.map(|u| LLMTokenUsage {
            prompt_tokens: u.prompt_token_count.unwrap_or(0),
            completion_tokens: u.candidates_token_count.unwrap_or(0),
            total_tokens: u.total_token_count.unwrap_or(0),
            prompt_tokens_details: None,
        });

        let choices = response
            .candidates
            .unwrap_or_default()
            .into_iter()
            .enumerate()
            .map(|(index, candidate)| {
                // A candidate without content becomes an empty model message.
                let message = candidate
                    .content
                    .map(|c| c.into())
                    .unwrap_or_else(|| LLMMessage {
                        role: "model".to_string(),
                        content: LLMMessageContent::String(String::new()),
                    });

                // Detect tool calls from the converted content so callers
                // can branch on the "tool_calls" finish reason.
                let has_tool_calls = match &message.content {
                    LLMMessageContent::List(items) => items
                        .iter()
                        .any(|item| matches!(item, LLMMessageTypedContent::ToolCall { .. })),
                    _ => false,
                };

                // Tool calls override whatever reason Gemini reported;
                // otherwise pass Gemini's reason through, lowercased.
                let finish_reason = if has_tool_calls {
                    Some("tool_calls".to_string())
                } else {
                    candidate.finish_reason.map(|s| s.to_lowercase())
                };

                LLMChoice {
                    finish_reason,
                    index: index as u32,
                    message,
                }
            })
            .collect();

        LLMCompletionResponse {
            model: response
                .model_version
                .unwrap_or_else(|| "gemini".to_string()),
            object: "chat.completion".to_string(),
            choices,
            // NOTE(review): this is milliseconds since epoch; the OpenAI
            // `created` field is conventionally seconds — confirm downstream.
            created: chrono::Utc::now().timestamp_millis() as u64,
            usage,
            id: response
                .response_id
                .unwrap_or_else(|| "unknown".to_string()),
        }
    }
}