ultrafast_models_sdk/providers/gemini.rs

use super::http_client::{map_error_response, AuthStrategy, HttpProviderClient};
use crate::error::ProviderError;
use crate::models::{
    AudioRequest, AudioResponse, ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
    ImageRequest, ImageResponse, Role, SpeechRequest, SpeechResponse, StreamChunk, Usage,
};
use crate::providers::{HealthStatus, Provider, ProviderConfig, ProviderHealth, StreamResult};
use async_stream::stream;
use std::collections::HashMap;
use std::time::Instant;
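/// Provider adapter for Google's Gemini API. Translates OpenAI-style
/// requests into Gemini's REST format and back, authenticating with the
/// `x-goog-api-key` header.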
pub struct GeminiProvider {
    http: HttpProviderClient,
    config: ProviderConfig,
}

impl GeminiProvider {
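    /// Builds a Gemini provider from the given configuration, defaulting the
    /// base URL to the public `v1beta` endpoint.
    ///
    /// A minimal construction sketch (this assumes `ProviderConfig`
    /// implements `Default`; only `api_key` is shown, other fields keep
    /// their defaults):
    ///
    /// ```ignore
    /// let config = ProviderConfig {
    ///     api_key: "YOUR_API_KEY".to_string(),
    ///     ..Default::default()
    /// };
    /// let provider = GeminiProvider::new(config)?;
    /// ```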
    pub fn new(config: ProviderConfig) -> Result<Self, ProviderError> {
        let http = HttpProviderClient::new(
            config.timeout,
            config.base_url.clone(),
            "https://generativelanguage.googleapis.com/v1beta",
            &config.headers,
            AuthStrategy::Header {
                name: "x-goog-api-key".to_string(),
                value: config.api_key.clone(),
            },
        )?;

        Ok(Self { http, config })
    }

    /// Resolves a caller-facing model name through the configured mapping,
    /// falling back to the name unchanged.
    fn map_model(&self, model: &str) -> String {
        self.config
            .model_mapping
            .get(model)
            .cloned()
            .unwrap_or_else(|| model.to_string())
    }
}

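// OpenAI-compatible `Provider` surface implemented on top of Gemini's
// REST endpoints.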
#[async_trait::async_trait]
impl Provider for GeminiProvider {
    fn name(&self) -> &str {
        "gemini"
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    fn supports_function_calling(&self) -> bool {
        false
    }

    fn supported_models(&self) -> Vec<String> {
        vec![
            "gemini-1.5-pro".to_string(),
            "gemini-1.5-pro-latest".to_string(),
            "gemini-1.5-flash".to_string(),
            "gemini-1.5-flash-latest".to_string(),
            "gemini-1.0-pro".to_string(),
            "gemini-1.0-pro-vision".to_string(),
        ]
    }

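    /// Non-streaming chat completion via `models/{model}:generateContent`.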
    async fn chat_completion(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
        let model = self.map_model(&request.model);
        let path = format!("/models/{model}:generateContent");

        let gemini_request = self.convert_to_gemini_format(request);
        let gemini_response: GeminiResponse = self.http.post_json(&path, &gemini_request).await?;
        let chat_response = self.convert_from_gemini_format(gemini_response);
        Ok(chat_response)
    }

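    /// Streaming chat completion via `models/{model}:streamGenerateContent`.
    /// The response body is consumed as newline-delimited JSON; lines that do
    /// not parse as complete JSON objects are skipped.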
    async fn stream_chat_completion(
        &self,
        request: ChatRequest,
    ) -> Result<StreamResult, ProviderError> {
        let model = self.map_model(&request.model);
        let path = format!("/models/{model}:streamGenerateContent");

        let gemini_request = self.convert_to_gemini_format(request);
        let response = self.http.post_json_raw(&path, &gemini_request).await?;
        if !response.status().is_success() {
            return Err(map_error_response(response).await);
        }

        let stream = Box::pin(stream! {
            let mut bytes_stream = response.bytes_stream();
            let mut buffer = String::new();

            while let Some(chunk_result) = futures::StreamExt::next(&mut bytes_stream).await {
                match chunk_result {
                    Ok(chunk) => {
                        let chunk_str = String::from_utf8_lossy(&chunk);
                        buffer.push_str(&chunk_str);

                        // Split the accumulated buffer on newlines and try to
                        // parse each line as one JSON chunk.
                        while let Some(line_end) = buffer.find('\n') {
                            let line = buffer[..line_end].trim().to_string();
                            buffer = buffer[line_end + 1..].to_string();

                            if !line.is_empty() {
                                match serde_json::from_str::<serde_json::Value>(&line) {
                                    Ok(gemini_chunk) => {
                                        if let Some(candidates) = gemini_chunk.get("candidates")
                                            .and_then(|c| c.as_array()) {
                                            for candidate in candidates {
                                                // Pull the text of the first part, if any,
                                                // and emit it as an OpenAI-style delta.
                                                if let Some(content) = candidate.get("content")
                                                    .and_then(|c| c.get("parts"))
                                                    .and_then(|p| p.as_array())
                                                    .and_then(|parts| parts.first())
                                                    .and_then(|part| part.get("text"))
                                                    .and_then(|t| t.as_str()) {

                                                    let stream_chunk = StreamChunk {
                                                        id: "gemini-stream".to_string(),
                                                        object: "chat.completion.chunk".to_string(),
                                                        created: chrono::Utc::now().timestamp() as u64,
                                                        model: model.clone(),
                                                        choices: vec![crate::models::StreamChoice {
                                                            index: 0,
                                                            delta: crate::models::Delta {
                                                                role: Some(Role::Assistant),
                                                                content: Some(content.to_string()),
                                                                tool_calls: None,
                                                            },
                                                            finish_reason: None,
                                                        }],
                                                    };
                                                    yield Ok(stream_chunk);
                                                }
                                            }
                                        }
                                    }
                                    Err(_) => {
                                        // Incomplete or non-object line; skip it.
                                        continue;
                                    }
                                }
                            }
                        }
                    }
                    Err(e) => {
                        tracing::error!("Stream error: {}", e);
                        break;
                    }
                }
            }
        });

        Ok(stream)
    }

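    /// Embeddings via `models/{model}:embedContent`.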
    async fn embedding(
        &self,
        request: EmbeddingRequest,
    ) -> Result<EmbeddingResponse, ProviderError> {
        let model = self.map_model(&request.model);
        let path = format!("/models/{model}:embedContent");

        let gemini_request = self.convert_to_gemini_embedding_format(request);
        let gemini_response: GeminiEmbeddingResponse =
            self.http.post_json(&path, &gemini_request).await?;
        let embedding_response = self.convert_from_gemini_embedding_format(gemini_response);
        Ok(embedding_response)
    }

    async fn image_generation(
        &self,
        _request: ImageRequest,
    ) -> Result<ImageResponse, ProviderError> {
        Err(ProviderError::Configuration {
            message: "Image generation is not supported by the Gemini provider".to_string(),
        })
    }

    async fn audio_transcription(
        &self,
        _request: AudioRequest,
    ) -> Result<AudioResponse, ProviderError> {
        Err(ProviderError::Configuration {
            message: "Audio transcription is not supported by the Gemini provider".to_string(),
        })
    }

    async fn text_to_speech(
        &self,
        _request: SpeechRequest,
    ) -> Result<SpeechResponse, ProviderError> {
        Err(ProviderError::Configuration {
            message: "Text-to-speech is not supported by the Gemini provider".to_string(),
        })
    }

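    /// Lightweight health probe: lists models and measures latency. Errors
    /// are reported as a `Degraded` status rather than failing the call.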
    async fn health_check(&self) -> Result<ProviderHealth, ProviderError> {
        let start = Instant::now();

        let result = self.http.get_json::<serde_json::Value>("/models").await;
        let latency = start.elapsed();
        match result {
            Ok(_) => Ok(ProviderHealth {
                status: HealthStatus::Healthy,
                latency_ms: Some(latency.as_millis() as u64),
                error_rate: 0.0,
                last_check: chrono::Utc::now(),
                details: HashMap::new(),
            }),
            Err(e) => {
                let mut details = HashMap::new();
                details.insert("error".to_string(), e.to_string());
                Ok(ProviderHealth {
                    status: HealthStatus::Degraded,
                    latency_ms: Some(latency.as_millis() as u64),
                    error_rate: 1.0,
                    last_check: chrono::Utc::now(),
                    details,
                })
            }
        }
    }
}

impl GeminiProvider {
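    /// Maps OpenAI-style messages onto Gemini `contents`. Gemini only
    /// distinguishes `user` and `model` roles, so system and tool messages
    /// are folded into `user`.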
    fn convert_to_gemini_format(&self, request: ChatRequest) -> GeminiRequest {
        let mut contents = Vec::new();

        for message in &request.messages {
            let role = match message.role {
                Role::User => "user",
                Role::Assistant => "model",
                // Gemini has no dedicated system or tool roles.
                Role::System => "user",
                Role::Tool => "user",
            };

            let parts = vec![GeminiPart {
                text: message.content.clone(),
            }];

            contents.push(GeminiContent {
                role: role.to_string(),
                parts,
            });
        }

        let generation_config = GeminiGenerationConfig {
            temperature: request.temperature,
            max_output_tokens: request.max_tokens.map(|t| t as i32),
            top_p: request.top_p,
            top_k: None,
        };

        GeminiRequest {
            contents,
            generation_config: Some(generation_config),
        }
    }

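    /// Converts a Gemini response into an OpenAI-style `ChatResponse`.
    /// Gemini does not return a response id, so placeholder id and model
    /// values are used.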
    fn convert_from_gemini_format(&self, response: GeminiResponse) -> ChatResponse {
        let mut choices = Vec::new();

        for (index, candidate) in response.candidates.iter().enumerate() {
            let content = candidate
                .content
                .parts
                .iter()
                .map(|part| part.text.clone())
                .collect::<Vec<String>>()
                .join("");

            choices.push(crate::models::Choice {
                index: index as u32,
                message: crate::models::Message {
                    role: Role::Assistant,
                    content,
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                },
                finish_reason: Some("stop".to_string()),
                logprobs: None,
            });
        }

        let usage = response.usage_metadata.map(|u| Usage {
            prompt_tokens: u.prompt_token_count,
            completion_tokens: u.candidates_token_count,
            total_tokens: u.total_token_count,
        });

        ChatResponse {
            id: "gemini-response".to_string(),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: "gemini-1.5-pro".to_string(),
            choices,
            usage,
            system_fingerprint: None,
        }
    }

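    /// Builds a Gemini `embedContent` request. Token-array inputs have no
    /// Gemini equivalent and are sent as empty text.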
    fn convert_to_gemini_embedding_format(
        &self,
        request: EmbeddingRequest,
    ) -> GeminiEmbeddingRequest {
        let text = match &request.input {
            crate::models::EmbeddingInput::String(s) => s.clone(),
            crate::models::EmbeddingInput::StringArray(arr) => arr.join(" "),
            // Pre-tokenized inputs are not supported by this provider.
            crate::models::EmbeddingInput::TokenArray(_) => "".to_string(),
            crate::models::EmbeddingInput::TokenArrayArray(_) => "".to_string(),
        };

        let content = GeminiEmbeddingContent {
            parts: vec![GeminiEmbeddingPart { text }],
        };

        GeminiEmbeddingRequest {
            content: Some(content),
        }
    }

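    /// Wraps Gemini embedding values in an OpenAI-style response. Gemini does
    /// not report token usage for embeddings, so usage is zeroed.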
    fn convert_from_gemini_embedding_format(
        &self,
        response: GeminiEmbeddingResponse,
    ) -> EmbeddingResponse {
        let embeddings = response.embedding.values;

        EmbeddingResponse {
            object: "list".to_string(),
            data: vec![crate::models::Embedding {
                object: "embedding".to_string(),
                embedding: embeddings,
                index: 0,
            }],
            model: "text-embedding-004".to_string(),
            usage: Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        }
    }
}

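// Wire-format types mirroring the Gemini REST JSON schema. The API uses
// camelCase field names on the wire, hence the `rename_all` attributes on
// structs whose Rust field names would otherwise map to snake_case.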
#[derive(serde::Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiRequest {
    contents: Vec<GeminiContent>,
    generation_config: Option<GeminiGenerationConfig>,
}

#[derive(serde::Serialize, serde::Deserialize)]
struct GeminiContent {
    role: String,
    parts: Vec<GeminiPart>,
}

#[derive(serde::Serialize, serde::Deserialize)]
struct GeminiPart {
    text: String,
}

#[derive(serde::Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiGenerationConfig {
    temperature: Option<f32>,
    max_output_tokens: Option<i32>,
    top_p: Option<f32>,
    top_k: Option<i32>,
}

#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiResponse {
    candidates: Vec<GeminiCandidate>,
    usage_metadata: Option<GeminiUsage>,
}

#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiCandidate {
    content: GeminiContent,
    #[allow(dead_code)]
    finish_reason: Option<String>,
}

#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiUsage {
    prompt_token_count: u32,
    candidates_token_count: u32,
    total_token_count: u32,
}

#[derive(serde::Serialize)]
struct GeminiEmbeddingRequest {
    content: Option<GeminiEmbeddingContent>,
}

#[derive(serde::Serialize)]
struct GeminiEmbeddingContent {
    parts: Vec<GeminiEmbeddingPart>,
}

#[derive(serde::Serialize)]
struct GeminiEmbeddingPart {
    text: String,
}

#[derive(serde::Deserialize)]
struct GeminiEmbeddingResponse {
    embedding: GeminiEmbedding,
}

#[derive(serde::Deserialize)]
struct GeminiEmbedding {
    values: Vec<f32>,
}