ultrafast_models_sdk/providers/gemini.rs

use super::http_client::{map_error_response, AuthStrategy, HttpProviderClient};
use crate::error::ProviderError;
use crate::models::{
    AudioRequest, AudioResponse, ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
    ImageRequest, ImageResponse, Role, SpeechRequest, SpeechResponse, StreamChunk, Usage,
};
use crate::providers::{HealthStatus, Provider, ProviderConfig, ProviderHealth, StreamResult};
use async_stream::stream;
use futures::StreamExt;
use std::collections::HashMap;
use std::time::Instant;

pub struct GeminiProvider {
    http: HttpProviderClient,
    config: ProviderConfig,
}

impl GeminiProvider {
    pub fn new(config: ProviderConfig) -> Result<Self, ProviderError> {
        // Gemini authenticates via an API key in the `x-goog-api-key` header, so use Header auth
        let http = HttpProviderClient::new(
            config.timeout,
            config.base_url.clone(),
            "https://generativelanguage.googleapis.com/v1beta",
            &config.headers,
            AuthStrategy::Header {
                name: "x-goog-api-key".to_string(),
                value: config.api_key.clone(),
            },
        )?;

        Ok(Self { http, config })
    }
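
    // A minimal construction sketch (hypothetical values; assumes
    // `ProviderConfig` provides a `Default` implementation):
    //
    //     let config = ProviderConfig {
    //         api_key: std::env::var("GEMINI_API_KEY").unwrap_or_default(),
    //         ..Default::default()
    //     };
    //     let provider = GeminiProvider::new(config)?;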

    fn map_model(&self, model: &str) -> String {
        self.config
            .model_mapping
            .get(model)
            .cloned()
            .unwrap_or_else(|| model.to_string())
    }
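
    // Example: with `model_mapping = {"gpt-4o": "gemini-1.5-pro"}`, callers can
    // keep OpenAI-style model names; unmapped names pass through unchanged.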

    // Error responses are mapped via the shared `map_error_response` helper
}

#[async_trait::async_trait]
impl Provider for GeminiProvider {
    fn name(&self) -> &str {
        "gemini"
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    fn supports_function_calling(&self) -> bool {
        false // The Gemini API does support function calling, but this provider does not implement it yet
    }

    fn supported_models(&self) -> Vec<String> {
        vec![
            "gemini-1.5-pro".to_string(),
            "gemini-1.5-pro-latest".to_string(),
            "gemini-1.5-flash".to_string(),
            "gemini-1.5-flash-latest".to_string(),
            "gemini-1.0-pro".to_string(),
            "gemini-1.0-pro-vision".to_string(),
        ]
    }

    async fn chat_completion(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
        let model = self.map_model(&request.model);
        let path = format!("/models/{model}:generateContent");

        // Convert the OpenAI-style request to the Gemini wire format
        let gemini_request = self.convert_to_gemini_format(request);
        let gemini_response: GeminiResponse = self.http.post_json(&path, &gemini_request).await?;
        let chat_response = self.convert_from_gemini_format(gemini_response, &model);
        Ok(chat_response)
    }

    async fn stream_chat_completion(
        &self,
        request: ChatRequest,
    ) -> Result<StreamResult, ProviderError> {
        let model = self.map_model(&request.model);
        // Request SSE framing; without `alt=sse` the REST endpoint streams a
        // single JSON array, which the line-delimited parser below cannot consume
        let path = format!("/models/{model}:streamGenerateContent?alt=sse");

        // Convert to the Gemini wire format and open the streaming request
        let gemini_request = self.convert_to_gemini_format(request);
        let response = self.http.post_json_raw(&path, &gemini_request).await?;
        if !response.status().is_success() {
            return Err(map_error_response(response).await);
        }
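
        // Each SSE frame looks roughly like this (illustrative, trimmed to the
        // fields the parser reads):
        //
        //     data: {"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"}}]}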
        let stream = Box::pin(stream! {
            let mut bytes_stream = response.bytes_stream();
            let mut buffer = String::new();

            while let Some(chunk_result) = bytes_stream.next().await {
                match chunk_result {
                    Ok(chunk) => {
                        let chunk_str = String::from_utf8_lossy(&chunk);
                        buffer.push_str(&chunk_str);

                        // Split the buffer into complete lines; anything after
                        // the last newline stays buffered for the next chunk
                        while let Some(line_end) = buffer.find('\n') {
                            let line = buffer[..line_end].trim().to_string();
                            buffer = buffer[line_end + 1..].to_string();

                            // SSE frames arrive as `data: {json}` lines; strip
                            // the prefix and skip blanks and non-data fields
                            let payload = line.strip_prefix("data:").map(str::trim).unwrap_or("");
                            if !payload.is_empty() {
                                // Parse the Gemini chunk and convert to the OpenAI format
                                match serde_json::from_str::<serde_json::Value>(payload) {
                                    Ok(gemini_chunk) => {
                                        if let Some(candidates) = gemini_chunk.get("candidates")
                                            .and_then(|c| c.as_array()) {
                                            for candidate in candidates {
                                                if let Some(content) = candidate.get("content")
                                                    .and_then(|c| c.get("parts"))
                                                    .and_then(|p| p.as_array())
                                                    .and_then(|parts| parts.first())
                                                    .and_then(|part| part.get("text"))
                                                    .and_then(|t| t.as_str()) {

                                                    let stream_chunk = StreamChunk {
                                                        id: "gemini-stream".to_string(),
                                                        object: "chat.completion.chunk".to_string(),
                                                        created: chrono::Utc::now().timestamp() as u64,
                                                        model: model.clone(),
                                                        choices: vec![crate::models::StreamChoice {
                                                            index: 0,
                                                            delta: crate::models::Delta {
                                                                role: Some(Role::Assistant),
                                                                content: Some(content.to_string()),
                                                                tool_calls: None,
                                                            },
                                                            finish_reason: None,
                                                        }],
                                                    };
                                                    yield Ok(stream_chunk);
                                                }
                                            }
                                        }
                                    }
                                    Err(_) => {
                                        // Skip lines that are not valid JSON
                                        continue;
                                    }
                                }
                            }
                        }
                    }
                    Err(e) => {
                        tracing::error!("Stream error: {}", e);
                        break;
                    }
                }
            }
        });

        Ok(stream)
    }

    async fn embedding(
        &self,
        request: EmbeddingRequest,
    ) -> Result<EmbeddingResponse, ProviderError> {
        let model = self.map_model(&request.model);
        let path = format!("/models/{model}:embedContent");

        // Convert to the Gemini embedding format
        let gemini_request = self.convert_to_gemini_embedding_format(request);
        let gemini_response: GeminiEmbeddingResponse =
            self.http.post_json(&path, &gemini_request).await?;
        let embedding_response =
            self.convert_from_gemini_embedding_format(gemini_response, &model);
        Ok(embedding_response)
    }

    async fn image_generation(
        &self,
        _request: ImageRequest,
    ) -> Result<ImageResponse, ProviderError> {
        Err(ProviderError::Configuration {
            message: "Image generation is not supported by the Gemini provider".to_string(),
        })
    }

    async fn audio_transcription(
        &self,
        _request: AudioRequest,
    ) -> Result<AudioResponse, ProviderError> {
        Err(ProviderError::Configuration {
            message: "Audio transcription is not supported by the Gemini provider".to_string(),
        })
    }

    async fn text_to_speech(
        &self,
        _request: SpeechRequest,
    ) -> Result<SpeechResponse, ProviderError> {
        Err(ProviderError::Configuration {
            message: "Text-to-speech is not supported by the Gemini provider".to_string(),
        })
    }

    async fn health_check(&self) -> Result<ProviderHealth, ProviderError> {
        let start = Instant::now();

        // List models as a lightweight health check
        let result = self.http.get_json::<serde_json::Value>("/models").await;
        let latency = start.elapsed();
        match result {
            Ok(_) => Ok(ProviderHealth {
                status: HealthStatus::Healthy,
                latency_ms: Some(latency.as_millis() as u64),
                error_rate: 0.0,
                last_check: chrono::Utc::now(),
                details: HashMap::new(),
            }),
            Err(e) => {
                let mut details = HashMap::new();
                details.insert("error".to_string(), e.to_string());
                Ok(ProviderHealth {
                    status: HealthStatus::Degraded,
                    latency_ms: Some(latency.as_millis() as u64),
                    error_rate: 1.0,
                    last_check: chrono::Utc::now(),
                    details,
                })
            }
        }
    }
}

impl GeminiProvider {
    fn convert_to_gemini_format(&self, request: ChatRequest) -> GeminiRequest {
        let mut contents = Vec::new();

        for message in &request.messages {
            let role = match message.role {
                Role::User => "user",
                Role::Assistant => "model",
                // Gemini's `contents` has no system or tool roles (system
                // prompts go through `systemInstruction`, which this provider
                // does not use), so both are folded into user turns
                Role::System => "user",
                Role::Tool => "user",
            };

            let parts = vec![GeminiPart {
                text: message.content.clone(),
            }];

            contents.push(GeminiContent {
                role: role.to_string(),
                parts,
            });
        }

        let generation_config = GeminiGenerationConfig {
            temperature: request.temperature,
            max_output_tokens: request.max_tokens.map(|t| t as i32),
            top_p: request.top_p,
            top_k: None,
        };

        GeminiRequest {
            contents,
            generation_config: Some(generation_config),
        }
    }
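
    // For a single user message, the request above serializes to JSON like
    // (illustrative; optional fields are omitted when unset):
    //
    //     {
    //       "contents": [{"role": "user", "parts": [{"text": "Hello"}]}],
    //       "generationConfig": {"temperature": 0.7, "maxOutputTokens": 256}
    //     }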

    fn convert_from_gemini_format(&self, response: GeminiResponse, model: &str) -> ChatResponse {
        let mut choices = Vec::new();

        for (index, candidate) in response.candidates.iter().enumerate() {
            let content = candidate
                .content
                .parts
                .iter()
                .map(|part| part.text.clone())
                .collect::<Vec<String>>()
                .join("");

            choices.push(crate::models::Choice {
                index: index as u32,
                message: crate::models::Message {
                    role: Role::Assistant,
                    content,
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                },
                finish_reason: Some("stop".to_string()),
                logprobs: None,
            });
        }

        let usage = response.usage_metadata.map(|u| Usage {
            prompt_tokens: u.prompt_token_count,
            completion_tokens: u.candidates_token_count,
            total_tokens: u.total_token_count,
        });
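
        // Gemini returns no OpenAI-style response id, so the id and created
        // timestamp are synthesized locally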
        ChatResponse {
            id: "gemini-response".to_string(),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: model.to_string(),
            choices,
            usage,
            system_fingerprint: None,
        }
    }
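
    // Example: a Gemini candidate whose parts are ["Hello", ", world"] becomes
    // one OpenAI-style choice with content "Hello, world" and finish_reason
    // "stop".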

    fn convert_to_gemini_embedding_format(
        &self,
        request: EmbeddingRequest,
    ) -> GeminiEmbeddingRequest {
        let text = match &request.input {
            crate::models::EmbeddingInput::String(s) => s.clone(),
            crate::models::EmbeddingInput::StringArray(arr) => arr.join(" "),
            // Pre-tokenized inputs are not supported by Gemini
            crate::models::EmbeddingInput::TokenArray(_) => "".to_string(),
            crate::models::EmbeddingInput::TokenArrayArray(_) => "".to_string(),
        };

        let content = GeminiEmbeddingContent {
            parts: vec![GeminiEmbeddingPart { text }],
        };

        GeminiEmbeddingRequest {
            content: Some(content),
        }
    }

    fn convert_from_gemini_embedding_format(
        &self,
        response: GeminiEmbeddingResponse,
        model: &str,
    ) -> EmbeddingResponse {
        let embeddings = response.embedding.values;

        EmbeddingResponse {
            object: "list".to_string(),
            data: vec![crate::models::Embedding {
                object: "embedding".to_string(),
                embedding: embeddings,
                index: 0,
            }],
            model: model.to_string(),
            // Gemini's embedContent response carries no token counts, so
            // usage is reported as zero
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        }
    }
}

// Gemini API request/response structures (the REST API emits and documents
// camelCase field names, so rename accordingly)
#[derive(serde::Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiRequest {
    contents: Vec<GeminiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GeminiGenerationConfig>,
}

#[derive(serde::Serialize, serde::Deserialize)]
struct GeminiContent {
    role: String,
    parts: Vec<GeminiPart>,
}

#[derive(serde::Serialize, serde::Deserialize)]
struct GeminiPart {
    text: String,
}

#[derive(serde::Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<i32>,
}

#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiResponse {
    candidates: Vec<GeminiCandidate>,
    usage_metadata: Option<GeminiUsage>,
}

#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiCandidate {
    content: GeminiContent,
    #[allow(dead_code)]
    finish_reason: Option<String>,
}

#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiUsage {
    prompt_token_count: u32,
    candidates_token_count: u32,
    total_token_count: u32,
}

#[derive(serde::Serialize)]
struct GeminiEmbeddingRequest {
    content: Option<GeminiEmbeddingContent>,
}

#[derive(serde::Serialize)]
struct GeminiEmbeddingContent {
    parts: Vec<GeminiEmbeddingPart>,
}

#[derive(serde::Serialize)]
struct GeminiEmbeddingPart {
    text: String,
}

#[derive(serde::Deserialize)]
struct GeminiEmbeddingResponse {
    embedding: GeminiEmbedding,
}

#[derive(serde::Deserialize)]
struct GeminiEmbedding {
    values: Vec<f32>,
}
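
#[cfg(test)]
mod tests {
    use super::*;

    // A small sanity check (sketch) that requests serialize to the camelCase
    // wire format the Gemini REST API documents; values here are arbitrary
    #[test]
    fn gemini_request_serializes_to_camel_case() {
        let request = GeminiRequest {
            contents: vec![GeminiContent {
                role: "user".to_string(),
                parts: vec![GeminiPart {
                    text: "Hi".to_string(),
                }],
            }],
            generation_config: Some(GeminiGenerationConfig {
                temperature: Some(0.2),
                max_output_tokens: Some(64),
                top_p: None,
                top_k: None,
            }),
        };

        let json = serde_json::to_value(&request).expect("serializes");
        assert!(json.get("generationConfig").is_some());
        assert_eq!(
            json["generationConfig"]["maxOutputTokens"],
            serde_json::json!(64)
        );
        // Unset optional fields are skipped entirely rather than sent as null
        assert!(json["generationConfig"].get("topP").is_none());
    }
}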