ultrafast_models_sdk/providers/google.rs

use super::http_client::{map_error_response, AuthStrategy, HttpProviderClient};
use crate::error::ProviderError;
use crate::models::{
    AudioRequest, AudioResponse, ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
    ImageRequest, ImageResponse, SpeechRequest, SpeechResponse, StreamChunk,
};
use crate::providers::{HealthStatus, Provider, ProviderConfig, ProviderHealth, StreamResult};
use async_stream::stream;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Instant;

pub struct GoogleVertexAIProvider {
    http: HttpProviderClient,
    config: ProviderConfig,
    base_url: String,
    #[allow(dead_code)]
    project_id: String,
    location: String,
}

impl GoogleVertexAIProvider {
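    /// Builds a provider from a `ProviderConfig`. The GCP project must be
    /// supplied via a `project-id` header entry; `location` is optional and
    /// defaults to `us-central1`.
    ///
    /// A minimal construction sketch (assumes `ProviderConfig` implements
    /// `Default` and exposes public fields; adapt to the real constructor):
    ///
    /// ```ignore
    /// let mut config = ProviderConfig::default();
    /// config.api_key = std::env::var("GOOGLE_ACCESS_TOKEN")?;
    /// config.headers.insert("project-id".into(), "my-gcp-project".into());
    /// config.headers.insert("location".into(), "us-central1".into());
    /// let provider = GoogleVertexAIProvider::new(config)?;
    /// ```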
    pub fn new(config: ProviderConfig) -> Result<Self, ProviderError> {
        let project_id = config.headers.get("project-id").cloned().ok_or_else(|| {
            ProviderError::Configuration {
                message: "project-id is required for Google Vertex AI".to_string(),
            }
        })?;

        let location = config
            .headers
            .get("location")
            .cloned()
            .unwrap_or_else(|| "us-central1".to_string());

        let base_url = config.base_url.clone().unwrap_or_else(|| {
            format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}")
        });

        let http = HttpProviderClient::new(
            config.timeout,
            Some(base_url.clone()),
            &base_url,
            &config.headers,
            AuthStrategy::Bearer {
                token: config.api_key.clone(),
            },
        )?;

        Ok(Self {
            http,
            config,
            base_url,
            project_id,
            location,
        })
    }

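    /// Builds the `:predict` URL for a published Google model, e.g.
    /// `https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/chat-bison:predict`.
    ///
    /// Note that `:predict` serves the PaLM chat/text and embedding models;
    /// Gemini models use `:generateContent` / `:streamGenerateContent` instead.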
    fn build_url(&self, model: &str) -> String {
        format!(
            "{}/locations/{}/publishers/google/models/{}:predict",
            self.base_url, self.location, model
        )
    }

    #[allow(dead_code)]
    fn build_headers(&self) -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();

        headers.insert(
            "Authorization",
            format!("Bearer {}", self.config.api_key).parse().unwrap(),
        );

        headers.insert("Content-Type", "application/json".parse().unwrap());

        for (key, value) in &self.config.headers {
            if let (Ok(header_name), Ok(header_value)) =
                (key.parse::<reqwest::header::HeaderName>(), value.parse())
            {
                headers.insert(header_name, header_value);
            }
        }

        headers
    }

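    /// Maps an incoming model name to its Vertex AI equivalent, preferring any
    /// explicit entry in `config.model_mapping`. For example, with no mapping
    /// configured, `"gpt-4"` resolves to `"chat-bison"` and
    /// `"text-embedding-ada-002"` to `"textembedding-gecko"`; unknown names
    /// pass through unchanged.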
    fn map_model(&self, model: &str) -> String {
        self.config
            .model_mapping
            .get(model)
            .cloned()
            .unwrap_or_else(|| {
                // Map common model names to Vertex AI equivalents
                match model {
                    "gpt-4" | "gpt-3.5-turbo" => "chat-bison".to_string(),
                    "text-embedding-ada-002" => "textembedding-gecko".to_string(),
                    _ => model.to_string(),
                }
            })
    }

    #[allow(dead_code)]
    async fn handle_error_response(&self, response: reqwest::Response) -> ProviderError {
        let status = response.status();

        match response.text().await {
            Ok(body) => {
                if let Ok(error_json) = serde_json::from_str::<serde_json::Value>(&body) {
                    let message = error_json
                        .get("error")
                        .and_then(|e| e.get("message"))
                        .and_then(|m| m.as_str())
                        .unwrap_or("Unknown API error")
                        .to_string();

                    match status.as_u16() {
                        401 => ProviderError::InvalidApiKey,
                        404 => ProviderError::ModelNotFound {
                            model: "unknown".to_string(),
                        },
                        429 => ProviderError::RateLimit,
                        _ => ProviderError::Api {
                            code: status.as_u16(),
                            message,
                        },
                    }
                } else {
                    ProviderError::Api {
                        code: status.as_u16(),
                        message: body,
                    }
                }
            }
            Err(_) => ProviderError::Api {
                code: status.as_u16(),
                message: "Failed to read error response".to_string(),
            },
        }
    }
}

#[async_trait::async_trait]
impl Provider for GoogleVertexAIProvider {
    fn name(&self) -> &str {
        "google-vertex-ai"
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    fn supports_function_calling(&self) -> bool {
        false
    }

    fn supported_models(&self) -> Vec<String> {
        vec![
            "chat-bison".to_string(),
            "chat-bison-32k".to_string(),
            "text-bison".to_string(),
            "text-bison-32k".to_string(),
            "gemini-pro".to_string(),
            "gemini-pro-vision".to_string(),
            "textembedding-gecko".to_string(),
            "textembedding-gecko-multilingual".to_string(),
        ]
    }
    async fn chat_completion(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
        let model = self.map_model(&request.model);
        let url = self.build_url(&model);
        // Convert OpenAI format to Vertex AI format
        let vertex_request = self.convert_to_vertex_format(request);

        let vertex_response: VertexAIResponse = self.http.post_json(&url, &vertex_request).await?;
        // Report the mapped model on the response rather than a hardcoded name.
        let chat_response = self.convert_from_vertex_format(vertex_response, &model);
        Ok(chat_response)
    }

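    /// Streams a chat completion via `streamGenerateContent` (the Gemini-style
    /// `contents` API) and translates each Vertex chunk into an OpenAI-style
    /// `StreamChunk`.
    ///
    /// Consuming the stream (sketch):
    ///
    /// ```ignore
    /// let mut stream = provider.stream_chat_completion(request).await?;
    /// while let Some(chunk) = futures::StreamExt::next(&mut stream).await {
    ///     if let Some(text) = chunk?.choices.first().and_then(|c| c.delta.content.as_ref()) {
    ///         print!("{text}");
    ///     }
    /// }
    /// ```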
    async fn stream_chat_completion(
        &self,
        request: ChatRequest,
    ) -> Result<StreamResult, ProviderError> {
        let model = self.map_model(&request.model);
        // Request SSE framing (`alt=sse`) so chunks arrive as newline-delimited
        // `data: {...}` lines; without it the endpoint streams one large JSON
        // array, which the line-based parser below could not handle.
        let url = format!(
            "{}/locations/{}/publishers/google/models/{}:streamGenerateContent?alt=sse",
            self.base_url, self.location, model
        );
        // Convert to Vertex AI streaming format
        let vertex_request = self.convert_to_vertex_streaming_format(request);

        let response = self.http.post_json_raw(&url, &vertex_request).await?;
        if !response.status().is_success() {
            return Err(map_error_response(response).await);
        }

        let stream = Box::pin(stream! {
            let mut bytes_stream = response.bytes_stream();
            let mut buffer = String::new();

            while let Some(chunk_result) = futures::StreamExt::next(&mut bytes_stream).await {
                match chunk_result {
                    Ok(chunk) => {
                        let chunk_str = String::from_utf8_lossy(&chunk);
                        buffer.push_str(&chunk_str);

                        while let Some(line_end) = buffer.find('\n') {
                            let line = buffer[..line_end].trim().to_string();
                            buffer = buffer[line_end + 1..].to_string();

                            // SSE payload lines look like `data: {...}`; skip blank
                            // keep-alive lines and anything that is not a data line.
                            let Some(payload) = line.strip_prefix("data:").map(str::trim) else {
                                continue;
                            };
                            if payload.is_empty() {
                                continue;
                            }

                            // Parse the Vertex AI chunk and convert it to the
                            // OpenAI-style format this SDK exposes.
                            match serde_json::from_str::<serde_json::Value>(payload) {
                                Ok(vertex_chunk) => {
                                    if let Some(candidates) = vertex_chunk.get("candidates")
                                        .and_then(|c| c.as_array()) {
                                        for candidate in candidates {
                                            if let Some(content) = candidate.get("content")
                                                .and_then(|c| c.get("parts"))
                                                .and_then(|p| p.as_array())
                                                .and_then(|parts| parts.first())
                                                .and_then(|part| part.get("text"))
                                                .and_then(|t| t.as_str()) {

                                                let stream_chunk = StreamChunk {
                                                    id: "vertex-stream".to_string(),
                                                    object: "chat.completion.chunk".to_string(),
                                                    created: chrono::Utc::now().timestamp() as u64,
                                                    model: model.clone(),
                                                    choices: vec![crate::models::StreamChoice {
                                                        index: 0,
                                                        delta: crate::models::Delta {
                                                            role: None,
                                                            content: Some(content.to_string()),
                                                            tool_calls: None,
                                                        },
                                                        finish_reason: None,
                                                    }],
                                                };
                                                yield Ok(stream_chunk);
                                            }
                                        }
                                    }
                                }
                                Err(e) => yield Err(ProviderError::Serialization(e)),
                            }
                        }
                    }
                    Err(e) => yield Err(ProviderError::Http(e)),
                }
            }
        });

        Ok(stream)
    }

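    /// Embeddings go through the same `:predict` endpoint as chat. The wire
    /// format, as modeled by the structs at the bottom of this file (sketch):
    ///
    /// ```text
    /// request:  {"instances": [{"content": "hello world"}]}
    /// response: {"predictions": [{"embeddings": {"values": [0.012, -0.034, ...]}}]}
    /// ```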
    async fn embedding(
        &self,
        request: EmbeddingRequest,
    ) -> Result<EmbeddingResponse, ProviderError> {
        let model = self.map_model(&request.model);
        let url = self.build_url(&model);
        // Convert to Vertex AI embedding format
        let vertex_embedding_request = VertexAIEmbeddingRequest {
            instances: vec![VertexAIEmbeddingInstance {
                content: match request.input {
                    crate::models::EmbeddingInput::String(s) => s,
                    _ => {
                        return Err(ProviderError::Configuration {
                            message:
                                "Only string input is supported for Google Vertex AI embeddings"
                                    .to_string(),
                        })
                    }
                },
            }],
        };

        let vertex_response: VertexAIEmbeddingResponse =
            self.http.post_json(&url, &vertex_embedding_request).await?;

        // Convert back to OpenAI format
        let embedding_response = EmbeddingResponse {
            object: "list".to_string(),
            data: vertex_response
                .predictions
                .into_iter()
                .map(|pred| crate::models::Embedding {
                    object: "embedding".to_string(),
                    embedding: pred.embeddings.values,
                    index: 0,
                })
                .collect(),
            model: request.model.clone(),
            // Token usage is not captured from the Vertex response, so it is zeroed.
            usage: crate::models::Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        };

        Ok(embedding_response)
    }

    async fn image_generation(
        &self,
        _request: ImageRequest,
    ) -> Result<ImageResponse, ProviderError> {
        Err(ProviderError::Configuration {
            message: "Google Vertex AI does not support image generation via this API".to_string(),
        })
    }

    async fn audio_transcription(
        &self,
        _request: AudioRequest,
    ) -> Result<AudioResponse, ProviderError> {
        Err(ProviderError::Configuration {
            message: "Google Vertex AI does not support audio transcription via this API"
                .to_string(),
        })
    }

    async fn text_to_speech(
        &self,
        _request: SpeechRequest,
    ) -> Result<SpeechResponse, ProviderError> {
        Err(ProviderError::Configuration {
            message: "Google Vertex AI does not support text-to-speech via this API".to_string(),
        })
    }

    async fn health_check(&self) -> Result<ProviderHealth, ProviderError> {
        let start = Instant::now();

        // Use a simple models list request for health check
        let url = format!(
            "{}/locations/{}/publishers/google/models",
            self.base_url, self.location
        );
        let response = self.http.get_json::<serde_json::Value>(&url).await;

        let latency_ms = start.elapsed().as_millis() as u64;

        match response {
            Ok(_) => Ok(ProviderHealth {
                status: HealthStatus::Healthy,
                latency_ms: Some(latency_ms),
                error_rate: 0.0,
                last_check: chrono::Utc::now(),
                details: HashMap::new(),
            }),
            Err(e) => {
                let mut details = HashMap::new();
                details.insert("error".to_string(), e.to_string());

                Ok(ProviderHealth {
                    status: HealthStatus::Degraded,
                    latency_ms: Some(latency_ms),
                    error_rate: 1.0,
                    last_check: chrono::Utc::now(),
                    details,
                })
            }
        }
    }
}

impl GoogleVertexAIProvider {
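    /// Builds a Gemini-style `generateContent` body (`contents`/`parts`).
    /// Note this provider speaks two wire formats: streaming uses this
    /// Gemini-style shape, while non-streaming chat and embeddings use the
    /// PaLM-style `:predict` shape (`instances`/`parameters`) below.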
    fn convert_to_vertex_streaming_format(&self, request: ChatRequest) -> VertexAIStreamRequest {
        let contents = request
            .messages
            .into_iter()
            .map(|msg| {
                VertexAIContent {
                    role: match msg.role {
                        // generateContent contents only accept "user" and "model"
                        // roles, so system and tool messages are folded into "user".
                        crate::models::Role::System => "user".to_string(),
                        crate::models::Role::User => "user".to_string(),
                        crate::models::Role::Assistant => "model".to_string(),
                        crate::models::Role::Tool => "user".to_string(),
                    },
                    parts: vec![VertexAIPart { text: msg.content }],
                }
            })
            .collect();

        let generation_config = VertexAIGenerationConfig {
            temperature: request.temperature,
            max_output_tokens: request.max_tokens.map(|t| t as i32),
            top_p: request.top_p,
            top_k: None,
        };

        VertexAIStreamRequest {
            contents,
            generation_config: Some(generation_config),
        }
    }

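    /// Converts an OpenAI-style `ChatRequest` into the PaLM `:predict`
    /// payload. The resulting JSON looks like (sketch):
    ///
    /// ```text
    /// {"instances": [{"messages": [{"author": "user", "content": "hi"}]}],
    ///  "parameters": {"temperature": 0.7, "maxOutputTokens": 1024}}
    /// ```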
    fn convert_to_vertex_format(&self, request: ChatRequest) -> VertexAIRequest {
        let messages = request
            .messages
            .into_iter()
            .map(|msg| VertexAIMessage {
                author: match msg.role {
                    crate::models::Role::System => "system".to_string(),
                    crate::models::Role::User => "user".to_string(),
                    crate::models::Role::Assistant => "assistant".to_string(),
                    crate::models::Role::Tool => "user".to_string(),
                },
                content: msg.content,
            })
            .collect();

        let parameters = VertexAIParameters {
            temperature: request.temperature.unwrap_or(0.7),
            max_output_tokens: request.max_tokens.unwrap_or(1024) as i32,
            top_p: request.top_p,
            top_k: None,
        };

        VertexAIRequest {
            instances: vec![VertexAIInstance { messages }],
            parameters: Some(parameters),
        }
    }

    fn convert_from_vertex_format(&self, response: VertexAIResponse, model: &str) -> ChatResponse {
        let choices = response
            .predictions
            .into_iter()
            .flat_map(|pred| {
                pred.candidates
                    .into_iter()
                    .map(|candidate| crate::models::Choice {
                        index: 0,
                        message: crate::models::Message {
                            role: crate::models::Role::Assistant,
                            content: candidate.content,
                            name: None,
                            tool_calls: None,
                            tool_call_id: None,
                        },
                        finish_reason: Some("stop".to_string()),
                        logprobs: None,
                    })
            })
            .collect();

        ChatResponse {
            id: uuid::Uuid::new_v4().to_string(),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: model.to_string(),
            choices,
            usage: None,
            system_fingerprint: None,
        }
    }
}

// Vertex AI specific data structures
#[derive(Debug, Serialize, Deserialize)]
struct VertexAIRequest {
    instances: Vec<VertexAIInstance>,
    parameters: Option<VertexAIParameters>,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIInstance {
    messages: Vec<VertexAIMessage>,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIMessage {
    author: String,
    content: String,
}

// Vertex AI's REST API expects camelCase JSON keys (e.g. `maxOutputTokens`);
// unset options are omitted rather than serialized as null.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct VertexAIParameters {
    temperature: f32,
    max_output_tokens: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<i32>,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIResponse {
    predictions: Vec<VertexAIPrediction>,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIPrediction {
    candidates: Vec<VertexAICandidate>,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAICandidate {
    content: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIEmbeddingRequest {
    instances: Vec<VertexAIEmbeddingInstance>,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIEmbeddingInstance {
    content: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIEmbeddingResponse {
    predictions: Vec<VertexAIEmbeddingPrediction>,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIEmbeddingPrediction {
    embeddings: VertexAIEmbeddings,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIEmbeddings {
    values: Vec<f32>,
}

// Streaming-specific structures
// As above, the API expects camelCase keys (e.g. `generationConfig`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct VertexAIStreamRequest {
    contents: Vec<VertexAIContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<VertexAIGenerationConfig>,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIContent {
    role: String,
    parts: Vec<VertexAIPart>,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIPart {
    text: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct VertexAIGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<i32>,
}
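
// A small round-trip sketch verifying that the structs above match the JSON
// shapes assumed throughout this file; the sample payloads are illustrative,
// not captured from the live API.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn deserializes_predict_chat_response() {
        let body = r#"{"predictions":[{"candidates":[{"content":"Hello!"}]}]}"#;
        let resp: VertexAIResponse = serde_json::from_str(body).unwrap();
        assert_eq!(resp.predictions[0].candidates[0].content, "Hello!");
    }

    #[test]
    fn deserializes_predict_embedding_response() {
        let body = r#"{"predictions":[{"embeddings":{"values":[0.1,0.2]}}]}"#;
        let resp: VertexAIEmbeddingResponse = serde_json::from_str(body).unwrap();
        assert_eq!(resp.predictions[0].embeddings.values.len(), 2);
    }

    #[test]
    fn serializes_generation_config_in_camel_case() {
        let config = VertexAIGenerationConfig {
            temperature: Some(0.5),
            max_output_tokens: Some(256),
            top_p: None,
            top_k: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        // camelCase keys, with unset options omitted entirely.
        assert_eq!(json, r#"{"temperature":0.5,"maxOutputTokens":256}"#);
    }
}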