use super::http_client::{map_error_response, AuthStrategy, HttpProviderClient};
use crate::error::ProviderError;
use crate::models::{
    AudioRequest, AudioResponse, ChatRequest, ChatResponse, EmbeddingRequest, EmbeddingResponse,
    ImageRequest, ImageResponse, SpeechRequest, SpeechResponse, StreamChunk,
};
use crate::providers::{HealthStatus, Provider, ProviderConfig, ProviderHealth, StreamResult};
use async_stream::stream;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Instant;

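/// Provider adapter for Google Vertex AI, routing OpenAI-style requests to
/// the `:predict` and `:streamGenerateContent` endpoints of a project's
/// publisher models.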
pub struct GoogleVertexAIProvider {
    http: HttpProviderClient,
    config: ProviderConfig,
    base_url: String,
    #[allow(dead_code)]
    project_id: String,
    location: String,
}

impl GoogleVertexAIProvider {
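    /// Creates a provider from `ProviderConfig`. The GCP project id must be
    /// supplied via the `project-id` header entry; `location` falls back to
    /// `us-central1`, and `base_url` is derived from both unless overridden.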
    pub fn new(config: ProviderConfig) -> Result<Self, ProviderError> {
        let project_id = config.headers.get("project-id").cloned().ok_or_else(|| {
            ProviderError::Configuration {
                message: "project-id is required for Google Vertex AI".to_string(),
            }
        })?;

        let location = config
            .headers
            .get("location")
            .cloned()
            .unwrap_or_else(|| "us-central1".to_string());

        let base_url = config.base_url.clone().unwrap_or_else(|| {
            format!("https://{location}-aiplatform.googleapis.com/v1/projects/{project_id}")
        });

        let http = HttpProviderClient::new(
            config.timeout,
            Some(base_url.clone()),
            &base_url,
            &config.headers,
            AuthStrategy::Bearer {
                token: config.api_key.clone(),
            },
        )?;

        Ok(Self {
            http,
            config,
            base_url,
            project_id,
            location,
        })
    }

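    /// Builds the `:predict` URL for a model under this project's location,
    /// as used by both chat and embedding requests.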
    fn build_url(&self, model: &str) -> String {
        format!(
            "{}/locations/{}/publishers/google/models/{}:predict",
            self.base_url, self.location, model
        )
    }

    #[allow(dead_code)]
    fn build_headers(&self) -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();

        // An API key containing bytes that are invalid in a header would make
        // `parse()` fail; skip the header rather than panicking.
        if let Ok(auth_value) = format!("Bearer {}", self.config.api_key).parse() {
            headers.insert("Authorization", auth_value);
        }

        headers.insert("Content-Type", "application/json".parse().unwrap());

        for (key, value) in &self.config.headers {
            if let (Ok(header_name), Ok(header_value)) =
                (key.parse::<reqwest::header::HeaderName>(), value.parse())
            {
                headers.insert(header_name, header_value);
            }
        }

        headers
    }

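    /// Resolves an incoming model name through the configured mapping, with
    /// built-in aliases for a few common OpenAI model names.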
    fn map_model(&self, model: &str) -> String {
        self.config
            .model_mapping
            .get(model)
            .cloned()
            .unwrap_or_else(|| match model {
                "gpt-4" | "gpt-3.5-turbo" => "chat-bison".to_string(),
                "text-embedding-ada-002" => "textembedding-gecko".to_string(),
                _ => model.to_string(),
            })
    }

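    /// Maps a non-success HTTP response to a `ProviderError`, extracting the
    /// message from Google's `{"error": {"message": ...}}` envelope when present.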
    #[allow(dead_code)]
    async fn handle_error_response(&self, response: reqwest::Response) -> ProviderError {
        let status = response.status();

        match response.text().await {
            Ok(body) => {
                if let Ok(error_json) = serde_json::from_str::<serde_json::Value>(&body) {
                    let message = error_json
                        .get("error")
                        .and_then(|e| e.get("message"))
                        .and_then(|m| m.as_str())
                        .unwrap_or("Unknown API error")
                        .to_string();

                    match status.as_u16() {
                        401 => ProviderError::InvalidApiKey,
                        404 => ProviderError::ModelNotFound {
                            model: "unknown".to_string(),
                        },
                        429 => ProviderError::RateLimit,
                        _ => ProviderError::Api {
                            code: status.as_u16(),
                            message,
                        },
                    }
                } else {
                    ProviderError::Api {
                        code: status.as_u16(),
                        message: body,
                    }
                }
            }
            Err(_) => ProviderError::Api {
                code: status.as_u16(),
                message: "Failed to read error response".to_string(),
            },
        }
    }
}

#[async_trait::async_trait]
impl Provider for GoogleVertexAIProvider {
    fn name(&self) -> &str {
        "google-vertex-ai"
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    fn supports_function_calling(&self) -> bool {
        false
    }

    fn supported_models(&self) -> Vec<String> {
        vec![
            "chat-bison".to_string(),
            "chat-bison-32k".to_string(),
            "text-bison".to_string(),
            "text-bison-32k".to_string(),
            "gemini-pro".to_string(),
            "gemini-pro-vision".to_string(),
            "textembedding-gecko".to_string(),
            "textembedding-gecko-multilingual".to_string(),
        ]
    }

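    // Non-streaming chat: call `:predict` and convert the Vertex response
    // back into an OpenAI-style `ChatResponse`.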
    async fn chat_completion(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
        let model = self.map_model(&request.model);
        let url = self.build_url(&model);
        let vertex_request = self.convert_to_vertex_format(request);

        let vertex_response: VertexAIResponse = self.http.post_json(&url, &vertex_request).await?;
        let chat_response = self.convert_from_vertex_format(vertex_response, &model);
        Ok(chat_response)
    }

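    // Streaming chat via `:streamGenerateContent`: the body is consumed as
    // newline-delimited JSON and re-emitted as OpenAI-style stream chunks.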
    async fn stream_chat_completion(
        &self,
        request: ChatRequest,
    ) -> Result<StreamResult, ProviderError> {
        let model = self.map_model(&request.model);
        let url = format!(
            "{}/locations/{}/publishers/google/models/{}:streamGenerateContent",
            self.base_url, self.location, model
        );
        let vertex_request = self.convert_to_vertex_streaming_format(request);

        let response = self.http.post_json_raw(&url, &vertex_request).await?;
        if !response.status().is_success() {
            return Err(map_error_response(response).await);
        }

        let stream = Box::pin(stream! {
            let mut bytes_stream = response.bytes_stream();
            // Buffer raw bytes so a multi-byte UTF-8 character split across
            // network chunks is never decoded in halves.
            let mut buffer: Vec<u8> = Vec::new();

            while let Some(chunk_result) = futures::StreamExt::next(&mut bytes_stream).await {
                match chunk_result {
                    Ok(chunk) => {
                        buffer.extend_from_slice(&chunk);

                        // Emit one JSON object per complete line in the buffer.
                        while let Some(line_end) = buffer.iter().position(|&b| b == b'\n') {
                            let line_bytes: Vec<u8> = buffer.drain(..=line_end).collect();
                            let line = String::from_utf8_lossy(&line_bytes).trim().to_string();

                            if line.is_empty() {
                                continue;
                            }

                            match serde_json::from_str::<serde_json::Value>(&line) {
                                Ok(vertex_chunk) => {
                                    if let Some(candidates) = vertex_chunk
                                        .get("candidates")
                                        .and_then(|c| c.as_array())
                                    {
                                        for candidate in candidates {
                                            if let Some(content) = candidate
                                                .get("content")
                                                .and_then(|c| c.get("parts"))
                                                .and_then(|p| p.as_array())
                                                .and_then(|parts| parts.first())
                                                .and_then(|part| part.get("text"))
                                                .and_then(|t| t.as_str())
                                            {
                                                let stream_chunk = StreamChunk {
                                                    id: "vertex-stream".to_string(),
                                                    object: "chat.completion.chunk".to_string(),
                                                    created: chrono::Utc::now().timestamp() as u64,
                                                    // Report the model actually called rather
                                                    // than a hard-coded "chat-bison".
                                                    model: model.clone(),
                                                    choices: vec![crate::models::StreamChoice {
                                                        index: 0,
                                                        delta: crate::models::Delta {
                                                            role: None,
                                                            content: Some(content.to_string()),
                                                            tool_calls: None,
                                                        },
                                                        finish_reason: None,
                                                    }],
                                                };
                                                yield Ok(stream_chunk);
                                            }
                                        }
                                    }
                                }
                                Err(e) => yield Err(ProviderError::Serialization(e)),
                            }
                        }
                    }
                    Err(e) => yield Err(ProviderError::Http(e)),
                }
            }
        });

        Ok(stream)
    }

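    // Embeddings via `:predict` on an embedding model; only single-string
    // input is converted.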
    async fn embedding(
        &self,
        request: EmbeddingRequest,
    ) -> Result<EmbeddingResponse, ProviderError> {
        let model = self.map_model(&request.model);
        let url = self.build_url(&model);

        let vertex_embedding_request = VertexAIEmbeddingRequest {
            instances: vec![VertexAIEmbeddingInstance {
                content: match request.input {
                    crate::models::EmbeddingInput::String(s) => s,
                    _ => {
                        return Err(ProviderError::Configuration {
                            message:
                                "Only string input is supported for Google Vertex AI embeddings"
                                    .to_string(),
                        })
                    }
                },
            }],
        };

        let vertex_response: VertexAIEmbeddingResponse =
            self.http.post_json(&url, &vertex_embedding_request).await?;

        // A single instance is sent, so the lone prediction keeps index 0;
        // Vertex does not report token usage here, hence the zeroed counts.
        let embedding_response = EmbeddingResponse {
            object: "list".to_string(),
            data: vertex_response
                .predictions
                .into_iter()
                .map(|pred| crate::models::Embedding {
                    object: "embedding".to_string(),
                    embedding: pred.embeddings.values,
                    index: 0,
                })
                .collect(),
            model: request.model.clone(),
            usage: crate::models::Usage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
        };

        Ok(embedding_response)
    }

    async fn image_generation(
        &self,
        _request: ImageRequest,
    ) -> Result<ImageResponse, ProviderError> {
        Err(ProviderError::Configuration {
            message: "Google Vertex AI does not support image generation via this API".to_string(),
        })
    }

    async fn audio_transcription(
        &self,
        _request: AudioRequest,
    ) -> Result<AudioResponse, ProviderError> {
        Err(ProviderError::Configuration {
            message: "Google Vertex AI does not support audio transcription via this API"
                .to_string(),
        })
    }

    async fn text_to_speech(
        &self,
        _request: SpeechRequest,
    ) -> Result<SpeechResponse, ProviderError> {
        Err(ProviderError::Configuration {
            message: "Google Vertex AI does not support text-to-speech via this API".to_string(),
        })
    }

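    // Health probe: list this location's publisher models and time the call;
    // any error downgrades the provider to `Degraded`.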
    async fn health_check(&self) -> Result<ProviderHealth, ProviderError> {
        let start = Instant::now();

        let url = format!(
            "{}/locations/{}/publishers/google/models",
            self.base_url, self.location
        );
        let response = self.http.get_json::<serde_json::Value>(&url).await;

        let latency_ms = start.elapsed().as_millis() as u64;

        match response {
            Ok(_) => Ok(ProviderHealth {
                status: HealthStatus::Healthy,
                latency_ms: Some(latency_ms),
                error_rate: 0.0,
                last_check: chrono::Utc::now(),
                details: HashMap::new(),
            }),
            Err(e) => {
                let mut details = HashMap::new();
                details.insert("error".to_string(), e.to_string());

                Ok(ProviderHealth {
                    status: HealthStatus::Degraded,
                    latency_ms: Some(latency_ms),
                    error_rate: 1.0,
                    last_check: chrono::Utc::now(),
                    details,
                })
            }
        }
    }
}

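// Conversions between the OpenAI-style request/response models and the
// Vertex AI wire formats defined at the bottom of this file.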
impl GoogleVertexAIProvider {
    fn convert_to_vertex_streaming_format(&self, request: ChatRequest) -> VertexAIStreamRequest {
        let contents = request
            .messages
            .into_iter()
            .map(|msg| VertexAIContent {
                // Gemini-style content only distinguishes "user" and "model";
                // system and tool messages are folded into "user".
                role: match msg.role {
                    crate::models::Role::Assistant => "model".to_string(),
                    crate::models::Role::System
                    | crate::models::Role::User
                    | crate::models::Role::Tool => "user".to_string(),
                },
                parts: vec![VertexAIPart { text: msg.content }],
            })
            .collect();

        let generation_config = VertexAIGenerationConfig {
            temperature: request.temperature,
            max_output_tokens: request.max_tokens.map(|t| t as i32),
            top_p: request.top_p,
            top_k: None,
        };

        VertexAIStreamRequest {
            contents,
            generation_config: Some(generation_config),
        }
    }

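    // PaLM-style (non-streaming) request: authored messages plus flat
    // generation parameters.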
    fn convert_to_vertex_format(&self, request: ChatRequest) -> VertexAIRequest {
        let messages = request
            .messages
            .into_iter()
            .map(|msg| VertexAIMessage {
                author: match msg.role {
                    crate::models::Role::System => "system".to_string(),
                    crate::models::Role::User => "user".to_string(),
                    crate::models::Role::Assistant => "assistant".to_string(),
                    crate::models::Role::Tool => "user".to_string(),
                },
                content: msg.content,
            })
            .collect();

        let parameters = VertexAIParameters {
            temperature: request.temperature.unwrap_or(0.7),
            max_output_tokens: request.max_tokens.unwrap_or(1024) as i32,
            top_p: request.top_p,
            top_k: None,
        };

        VertexAIRequest {
            instances: vec![VertexAIInstance { messages }],
            parameters: Some(parameters),
        }
    }

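    // Flattens every candidate of every prediction into OpenAI-style choices
    // and stamps the response with the model that was called.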
    fn convert_from_vertex_format(&self, response: VertexAIResponse, model: &str) -> ChatResponse {
        let choices = response
            .predictions
            .into_iter()
            .flat_map(|pred| pred.candidates)
            .enumerate()
            .map(|(i, candidate)| crate::models::Choice {
                // Number choices sequentially instead of fixing every index at 0.
                index: i as _,
                message: crate::models::Message {
                    role: crate::models::Role::Assistant,
                    content: candidate.content,
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                },
                finish_reason: Some("stop".to_string()),
                logprobs: None,
            })
            .collect();

        ChatResponse {
            id: uuid::Uuid::new_v4().to_string(),
            object: "chat.completion".to_string(),
            created: chrono::Utc::now().timestamp() as u64,
            model: model.to_string(),
            choices,
            usage: None,
            system_fingerprint: None,
        }
    }
}

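// Wire-format types for the Vertex AI `:predict` and `:streamGenerateContent`
// payloads.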
#[derive(Debug, Serialize, Deserialize)]
struct VertexAIRequest {
    instances: Vec<VertexAIInstance>,
    #[serde(skip_serializing_if = "Option::is_none")]
    parameters: Option<VertexAIParameters>,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIInstance {
    messages: Vec<VertexAIMessage>,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIMessage {
    author: String,
    content: String,
}

// Vertex AI JSON payloads use lowerCamelCase field names, e.g. `maxOutputTokens`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct VertexAIParameters {
    temperature: f32,
    max_output_tokens: i32,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<i32>,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIResponse {
    predictions: Vec<VertexAIPrediction>,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIPrediction {
    candidates: Vec<VertexAICandidate>,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAICandidate {
    content: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIEmbeddingRequest {
    instances: Vec<VertexAIEmbeddingInstance>,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIEmbeddingInstance {
    content: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIEmbeddingResponse {
    predictions: Vec<VertexAIEmbeddingPrediction>,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIEmbeddingPrediction {
    embeddings: VertexAIEmbeddings,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIEmbeddings {
    values: Vec<f32>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct VertexAIStreamRequest {
    contents: Vec<VertexAIContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<VertexAIGenerationConfig>,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIContent {
    role: String,
    parts: Vec<VertexAIPart>,
}

#[derive(Debug, Serialize, Deserialize)]
struct VertexAIPart {
    text: String,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct VertexAIGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<i32>,
}