1use crate::model_type::{EmbeddingModelType, OllamaTextEmbeddingsInference};
2use crate::zoo_embedding_errors::ZooEmbeddingError;
3use async_trait::async_trait;
4
5use lazy_static::lazy_static;
6
7use reqwest::blocking::Client;
8
9use reqwest::Client as AsyncClient;
10use reqwest::ClientBuilder;
11use serde::{Deserialize, Serialize};
12use std::time::Duration;
13
lazy_static! {
    /// Base URL of the remote embeddings service; used by
    /// `RemoteEmbeddingGenerator::new_default`.
    pub static ref DEFAULT_EMBEDDINGS_SERVER_URL: &'static str = "https://api.zoo.ngo/embeddings";
    /// Base URL used by `RemoteEmbeddingGenerator::new_default_local`.
    /// NOTE(review): port 11434 on localhost looks like a local Ollama server — confirm.
    pub static ref DEFAULT_EMBEDDINGS_LOCAL_URL: &'static str = "http://localhost:11434/";
}
21
#[async_trait]
/// Common interface for types that turn text into embedding vectors.
///
/// Every operation is offered in a blocking and an `async` flavour, for a
/// single string or for a batch of strings. The `*_default` methods have
/// provided bodies that simply forward to the corresponding required method;
/// implementors may override them.
pub trait EmbeddingGenerator: Sync + Send {
    /// Returns the embedding model type of this generator.
    fn model_type(&self) -> EmbeddingModelType;
    /// Replaces the embedding model type of this generator.
    fn set_model_type(&mut self, model_type: EmbeddingModelType);
    /// Returns a boxed copy of this generator as a trait object.
    fn box_clone(&self) -> Box<dyn EmbeddingGenerator>;

    /// Generates the embedding for one string, blocking the calling thread.
    fn generate_embedding_blocking(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError>;

    /// Provided method: forwards to [`generate_embedding_blocking`](Self::generate_embedding_blocking).
    fn generate_embedding_default_blocking(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError> {
        self.generate_embedding_blocking(input_string)
    }

    /// Generates one embedding per input string, blocking the calling thread.
    fn generate_embeddings_blocking(&self, input_strings: &Vec<String>)
        -> Result<Vec<Vec<f32>>, ZooEmbeddingError>;

    /// Provided method: forwards to [`generate_embeddings_blocking`](Self::generate_embeddings_blocking).
    fn generate_embeddings_blocking_default(
        &self,
        input_strings: &Vec<String>,
    ) -> Result<Vec<Vec<f32>>, ZooEmbeddingError> {
        self.generate_embeddings_blocking(input_strings)
    }

    /// Generates the embedding for one string asynchronously.
    async fn generate_embedding(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError>;

    /// Provided method: forwards to [`generate_embedding`](Self::generate_embedding).
    async fn generate_embedding_default(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError> {
        self.generate_embedding(input_string).await
    }
    /// Generates one embedding per input string asynchronously.
    async fn generate_embeddings(&self, input_strings: &Vec<String>) -> Result<Vec<Vec<f32>>, ZooEmbeddingError>;

    /// Provided method: forwards to [`generate_embeddings`](Self::generate_embeddings).
    async fn generate_embeddings_default(
        &self,
        input_strings: &Vec<String>,
    ) -> Result<Vec<Vec<f32>>, ZooEmbeddingError> {
        self.generate_embeddings(input_strings).await
    }
}
73
/// Embedding generator that obtains its vectors from an HTTP embeddings API.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]

pub struct RemoteEmbeddingGenerator {
    /// Embedding model requested from the API; also determines the input
    /// truncation limit via `max_input_token_count()`.
    pub model_type: EmbeddingModelType,
    /// Base URL of the API. The Ollama and TEI helpers append their endpoint
    /// path to it; the OpenAI-style method posts to it unchanged.
    pub api_url: String,
    /// Optional API key; when present it is sent as an
    /// `Authorization: Bearer <key>` header.
    pub api_key: Option<String>,
}
81
82#[async_trait]
83impl EmbeddingGenerator for RemoteEmbeddingGenerator {
84 fn box_clone(&self) -> Box<dyn EmbeddingGenerator> {
86 Box::new(self.clone())
87 }
88
89 fn generate_embeddings_blocking(
93 &self,
94 input_strings: &Vec<String>,
95 ) -> Result<Vec<Vec<f32>>, ZooEmbeddingError> {
96 let input_strings: Vec<String> = input_strings
97 .iter()
98 .map(|s| s.chars().take(self.model_type.max_input_token_count()).collect())
99 .collect();
100
101 match self.model_type {
102 EmbeddingModelType::OllamaTextEmbeddingsInference(_) => {
103 let mut embeddings = Vec::new();
104 for input_string in input_strings.iter() {
105 let embedding = self.generate_embedding_ollama_blocking(input_string)?;
106 embeddings.push(embedding);
107 }
108 Ok(embeddings)
109 }
110 }
111 }
112
113 fn generate_embedding_blocking(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError> {
116 let input_strings = [input_string.to_string()];
117 let input_strings: Vec<String> = input_strings
118 .iter()
119 .map(|s| s.chars().take(self.model_type.max_input_token_count()).collect())
120 .collect();
121
122 let results = self.generate_embeddings_blocking(&input_strings)?;
123 if results.is_empty() {
124 Err(ZooEmbeddingError::FailedEmbeddingGeneration(
125 "No results returned from the embedding generation".to_string(),
126 ))
127 } else {
128 Ok(results[0].clone())
129 }
130 }
131
132 async fn generate_embeddings(&self, input_strings: &Vec<String>) -> Result<Vec<Vec<f32>>, ZooEmbeddingError> {
135 let input_strings: Vec<String> = input_strings
136 .iter()
137 .map(|s| s.chars().take(self.model_type.max_input_token_count()).collect())
138 .collect();
139
140 match self.model_type.clone() {
141 EmbeddingModelType::OllamaTextEmbeddingsInference(model) => {
142 let mut embeddings = Vec::new();
143 for input_string in input_strings.iter() {
144 let embedding = self
145 .generate_embedding_ollama(input_string.clone(), model.to_string())
146 .await?;
147 embeddings.push(embedding);
148 }
149 Ok(embeddings)
150 }
151 }
152 }
153
154 async fn generate_embedding(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError> {
156 let input_strings = [input_string.to_string()];
157 let input_strings: Vec<String> = input_strings
158 .iter()
159 .map(|s| s.chars().take(self.model_type.max_input_token_count()).collect())
160 .collect();
161
162 let results = self.generate_embeddings(&input_strings).await?;
163 if results.is_empty() {
164 Err(ZooEmbeddingError::FailedEmbeddingGeneration(
165 "No results returned from the embedding generation".to_string(),
166 ))
167 } else {
168 Ok(results[0].clone())
169 }
170 }
171
172 fn model_type(&self) -> EmbeddingModelType {
174 self.model_type.clone()
175 }
176
177 fn set_model_type(&mut self, model_type: EmbeddingModelType) {
179 self.model_type = model_type
180 }
181}
182
183impl RemoteEmbeddingGenerator {
184 pub fn new(model_type: EmbeddingModelType, api_url: &str, api_key: Option<String>) -> RemoteEmbeddingGenerator {
186 RemoteEmbeddingGenerator {
187 model_type,
188 api_url: api_url.to_string(),
189 api_key,
190 }
191 }
192
193 pub fn new_default() -> RemoteEmbeddingGenerator {
195 let model_architecture =
196 EmbeddingModelType::OllamaTextEmbeddingsInference(OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM);
197 RemoteEmbeddingGenerator {
198 model_type: model_architecture,
199 api_url: DEFAULT_EMBEDDINGS_SERVER_URL.to_string(),
200 api_key: None,
201 }
202 }
203 pub fn new_default_local() -> RemoteEmbeddingGenerator {
205 let model_architecture =
206 EmbeddingModelType::OllamaTextEmbeddingsInference(OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM);
207 RemoteEmbeddingGenerator {
208 model_type: model_architecture,
209 api_url: DEFAULT_EMBEDDINGS_LOCAL_URL.to_string(),
210 api_key: None,
211 }
212 }
213
214 fn tei_endpoint_url(&self) -> String {
217 if self.api_url.ends_with('/') {
218 format!("{}embed", self.api_url)
219 } else {
220 format!("{}/embed", self.api_url)
221 }
222 }
223
224 fn ollama_endpoint_url(&self) -> String {
227 if self.api_url.ends_with('/') {
228 format!("{}api/embeddings", self.api_url)
229 } else {
230 format!("{}/api/embeddings", self.api_url)
231 }
232 }
233
234 pub async fn generate_embedding_ollama(
237 &self,
238 input_string: String,
239 model: String,
240 ) -> Result<Vec<f32>, ZooEmbeddingError> {
241 let max_retries = 3;
242 let mut retry_count = 0;
243 let mut shortening_retry = 0;
244 let mut input_string = input_string.clone();
245
246 loop {
247 let request_body = OllamaEmbeddingsRequestBody {
249 model: model.clone(),
250 prompt: input_string.clone(),
251 };
252
253 let timeout = Duration::from_secs(60);
255 let client = ClientBuilder::new().timeout(timeout).build()?;
256
257 let mut request = client
259 .post(self.ollama_endpoint_url().to_string())
260 .header("Content-Type", "application/json")
261 .json(&request_body);
262
263 if let Some(api_key) = &self.api_key {
265 request = request.header("Authorization", format!("Bearer {}", api_key));
266 }
267
268 let response = request.send().await;
270
271 match response {
272 Ok(response) if response.status().is_success() => {
273 let embedding_response: Result<OllamaEmbeddingsResponse, _> =
274 response.json::<OllamaEmbeddingsResponse>().await;
275 match embedding_response {
276 Ok(embedding_response) => {
277 return Ok(embedding_response.embedding);
278 }
279 Err(err) => {
280 return Err(ZooEmbeddingError::RequestFailed(format!(
281 "Failed to deserialize response JSON: {}",
282 err
283 )));
284 }
285 }
286 }
287 Ok(response) if response.status() == reqwest::StatusCode::PAYLOAD_TOO_LARGE => {
288 let reduction_step = if shortening_retry > 1 {
290 100 * shortening_retry
291 } else {
292 50
293 };
294 let shortened_max_size = input_string.len().saturating_sub(reduction_step).max(5);
295 input_string = input_string.chars().take(shortened_max_size).collect();
296
297 retry_count = 0;
298 shortening_retry += 1;
299 if shortening_retry > 10 {
300 return Err(ZooEmbeddingError::RequestFailed(format!(
301 "HTTP request failed after multiple recursive iterations shortening input. Status: {}",
302 response.status()
303 )));
304 }
305 continue;
306 }
307 Ok(response) => {
308 return Err(ZooEmbeddingError::RequestFailed(format!(
309 "HTTP request failed with status: {}",
310 response.status()
311 )));
312 }
313 Err(err) => {
314 if retry_count < max_retries {
315 retry_count += 1;
316 continue;
317 } else {
318 return Err(ZooEmbeddingError::RequestFailed(format!(
319 "HTTP request failed after {} retries: {}",
320 max_retries, err
321 )));
322 }
323 }
324 }
325 }
326 }
327
328 fn generate_embedding_ollama_blocking(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError> {
330 let request_body = OllamaEmbeddingsRequestBody {
332 model: self.model_type.to_string(),
333 prompt: String::from(input_string),
334 };
335
336 let client = Client::new();
338
339 let mut request = client
341 .post(&format!("{}", self.ollama_endpoint_url()))
342 .header("Content-Type", "application/json")
343 .json(&request_body);
344
345 if let Some(api_key) = &self.api_key {
347 request = request.header("Authorization", format!("Bearer {}", api_key));
348 }
349
350 let response = request.send().map_err(|err| {
352 ZooEmbeddingError::RequestFailed(format!("HTTP request failed: {}", err))
354 })?;
355
356 if response.status().is_success() {
358 let embedding_response: OllamaEmbeddingsResponse = response.json().map_err(|err| {
359 ZooEmbeddingError::RequestFailed(format!("Failed to deserialize response JSON: {}", err))
360 })?;
361 Ok(embedding_response.embedding)
362 } else {
363 Err(ZooEmbeddingError::RequestFailed(format!(
364 "HTTP request failed with status: {}",
365 response.status()
366 )))
367 }
368 }
369
370 pub async fn generate_embedding_tei(
372 &self,
373 input_strings: Vec<String>,
374 ) -> Result<Vec<Vec<f32>>, ZooEmbeddingError> {
375 let max_retries = 3;
376 let mut retry_count = 0;
377 let mut shortening_retry = 0;
378 let mut current_input_strings = input_strings.clone();
379
380 loop {
381 let request_body = EmbeddingArrayRequestBody {
383 inputs: current_input_strings.iter().map(|s| s.to_string()).collect(),
384 };
385
386 let timeout = Duration::from_secs(60);
388 let client = ClientBuilder::new().timeout(timeout).build()?;
389
390 let mut request = client
392 .post(self.tei_endpoint_url().to_string())
393 .header("Content-Type", "application/json")
394 .json(&request_body);
395
396 if let Some(api_key) = &self.api_key {
398 request = request.header("Authorization", format!("Bearer {}", api_key));
399 }
400
401 let response = request.send().await;
403
404 match response {
405 Ok(response) if response.status().is_success() => {
406 let embedding_response: Result<Vec<Vec<f32>>, _> = response.json::<Vec<Vec<f32>>>().await;
407 match embedding_response {
408 Ok(embedding_response) => {
409 return Ok(embedding_response);
410 }
411 Err(err) => {
412 return Err(ZooEmbeddingError::RequestFailed(format!(
413 "Failed to deserialize response JSON: {}",
414 err
415 )));
416 }
417 }
418 }
419 Ok(response) if response.status() == reqwest::StatusCode::PAYLOAD_TOO_LARGE => {
420 let max_size = current_input_strings.iter().map(|s| s.len()).max().unwrap_or(0);
421 let reduction_step = if shortening_retry > 1 {
423 100 * shortening_retry
424 } else {
425 50
426 };
427 let shortened_max_size = max_size.saturating_sub(reduction_step).max(5);
428 current_input_strings = current_input_strings
429 .iter()
430 .map(|s| {
431 if s.len() > shortened_max_size {
432 s.chars().take(shortened_max_size).collect()
433 } else {
434 s.clone()
435 }
436 })
437 .collect();
438 retry_count = 0;
439 shortening_retry += 1;
440 if shortening_retry > 10 {
441 return Err(ZooEmbeddingError::RequestFailed(format!(
442 "HTTP request failed after multiple recursive iterations shortening input. Status: {}",
443 response.status()
444 )));
445 }
446 continue;
447 }
448 Ok(response) => {
449 return Err(ZooEmbeddingError::RequestFailed(format!(
450 "HTTP request failed with status: {}",
451 response.status()
452 )));
453 }
454 Err(err) => {
455 if retry_count < max_retries {
456 retry_count += 1;
457 continue;
458 } else {
459 return Err(ZooEmbeddingError::RequestFailed(format!(
460 "HTTP request failed after {} retries: {}",
461 max_retries, err
462 )));
463 }
464 }
465 }
466 }
467 }
468
469 pub async fn generate_embedding_open_ai(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError> {
471 let request_body = EmbeddingRequestBody {
473 input: String::from(input_string),
474 model: self.model_type().to_string(),
475 };
476
477 let client = AsyncClient::new();
479
480 let mut request = client
482 .post(self.api_url.to_string())
483 .header("Content-Type", "application/json")
484 .json(&request_body);
485
486 if let Some(api_key) = &self.api_key {
488 request = request.header("Authorization", format!("Bearer {}", api_key));
489 }
490
491 let response = request.send().await.map_err(|err| {
493 ZooEmbeddingError::RequestFailed(format!("HTTP request failed: {}", err))
495 })?;
496
497 if response.status().is_success() {
499 let embedding_response: EmbeddingResponse = response.json().await.map_err(|err| {
502 ZooEmbeddingError::RequestFailed(format!("Failed to deserialize response JSON: {}", err))
503 })?;
504
505 Ok(embedding_response.data[0].embedding.clone())
507 } else {
508 Err(ZooEmbeddingError::RequestFailed(format!(
510 "HTTP request failed with status: {}",
511 response.status()
512 )))
513 }
514 }
515}
516
/// JSON request body sent by `generate_embedding_open_ai`.
#[derive(Serialize)]
// Fields are only consumed through serde serialization.
#[allow(dead_code)]
struct EmbeddingRequestBody {
    /// Text to embed.
    input: String,
    /// Model name, taken from the generator's model type.
    model: String,
}
523
/// One entry of the `data` array in an [`EmbeddingResponse`].
#[derive(Deserialize)]
// Only `embedding` is read in the code visible here; the other fields are
// deserialized but otherwise unused.
#[allow(dead_code)]
struct EmbeddingResponseData {
    /// The embedding vector.
    embedding: Vec<f32>,
    /// NOTE(review): presumably the position of the input this entry belongs to — confirm against the API.
    index: usize,
    /// NOTE(review): presumably an object-type tag supplied by the API — confirm.
    object: String,
}
531
/// JSON response body decoded by `generate_embedding_open_ai`.
#[derive(Deserialize)]
// Only `data` is read in the code visible here; the other fields are
// deserialized but otherwise unused.
#[allow(dead_code)]
struct EmbeddingResponse {
    /// NOTE(review): presumably an object-type tag supplied by the API — confirm.
    object: String,
    /// Model name as reported in the response.
    model: String,
    /// Returned embedding entries; the caller uses the first one.
    data: Vec<EmbeddingResponseData>,
    /// Kept as untyped JSON; its structure is not interpreted here.
    usage: serde_json::Value,
}
540
/// JSON request body sent by `generate_embedding_tei`.
#[derive(Serialize)]
// The field is only consumed through serde serialization.
#[allow(dead_code)]
struct EmbeddingArrayRequestBody {
    /// Batch of strings to embed in one request.
    inputs: Vec<String>,
}
546
/// JSON request body sent to the Ollama `api/embeddings` endpoint.
#[derive(Debug, Serialize)]
// Fields are only consumed through serde serialization.
#[allow(dead_code)]
struct OllamaEmbeddingsRequestBody {
    /// Name of the model to run.
    model: String,
    /// Text to embed.
    prompt: String,
}
553
/// JSON response body decoded from the Ollama `api/embeddings` endpoint.
#[derive(Deserialize)]
#[allow(dead_code)]
struct OllamaEmbeddingsResponse {
    /// The embedding vector for the submitted prompt.
    embedding: Vec<f32>,
}