Skip to main content

zoo_embedding/
embedding_generator.rs

1use crate::model_type::{EmbeddingModelType, OllamaTextEmbeddingsInference};
2use crate::zoo_embedding_errors::ZooEmbeddingError;
3use async_trait::async_trait;
4
5use lazy_static::lazy_static;
6
7use reqwest::blocking::Client;
8
9use reqwest::Client as AsyncClient;
10use reqwest::ClientBuilder;
11use serde::{Deserialize, Serialize};
12use std::time::Duration;
13
14// TODO: remove duplicate methods
15// TODO: remove blocking / non-blocking methods
16
17lazy_static! {
18    pub static ref DEFAULT_EMBEDDINGS_SERVER_URL: &'static str = "https://api.zoo.ngo/embeddings";
19    pub static ref DEFAULT_EMBEDDINGS_LOCAL_URL: &'static str = "http://localhost:11434/";
20}
21
22/// A trait for types that can generate embeddings from text.
23#[async_trait]
24pub trait EmbeddingGenerator: Sync + Send {
25    fn model_type(&self) -> EmbeddingModelType;
26    fn set_model_type(&mut self, model_type: EmbeddingModelType);
27    fn box_clone(&self) -> Box<dyn EmbeddingGenerator>;
28
29    /// Generates an embedding from the given input string, and assigns the
30    /// provided id.
31    fn generate_embedding_blocking(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError>;
32
33    /// Generate an Embedding for an input string, sets id to a default value
34    /// of empty string.
35    fn generate_embedding_default_blocking(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError> {
36        self.generate_embedding_blocking(input_string)
37    }
38
39    /// Generates embeddings from the given list of input strings and ids.
40    fn generate_embeddings_blocking(&self, input_strings: &Vec<String>)
41        -> Result<Vec<Vec<f32>>, ZooEmbeddingError>;
42
43    /// Generate Embeddings for a list of input strings, sets ids to default.
44    fn generate_embeddings_blocking_default(
45        &self,
46        input_strings: &Vec<String>,
47    ) -> Result<Vec<Vec<f32>>, ZooEmbeddingError> {
48        self.generate_embeddings_blocking(input_strings)
49    }
50
51    /// Generates an embedding from the given input string, and assigns the
52    /// provided id.
53    async fn generate_embedding(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError>;
54
55    /// Generate an Embedding for an input string, sets id to a default value
56    /// of empty string.
57    async fn generate_embedding_default(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError> {
58        self.generate_embedding(input_string).await
59    }
60    // ### TODO: remove all these duplicate methods
61
62    /// Generates embeddings from the given list of input strings and ids.
63    async fn generate_embeddings(&self, input_strings: &Vec<String>) -> Result<Vec<Vec<f32>>, ZooEmbeddingError>;
64
65    /// Generate Embeddings for a list of input strings, sets ids to default
66    async fn generate_embeddings_default(
67        &self,
68        input_strings: &Vec<String>,
69    ) -> Result<Vec<Vec<f32>>, ZooEmbeddingError> {
70        self.generate_embeddings(input_strings).await
71    }
72}
73
74#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
75
76pub struct RemoteEmbeddingGenerator {
77    pub model_type: EmbeddingModelType,
78    pub api_url: String,
79    pub api_key: Option<String>,
80}
81
82#[async_trait]
83impl EmbeddingGenerator for RemoteEmbeddingGenerator {
84    /// Clones self and wraps it in a Box
85    fn box_clone(&self) -> Box<dyn EmbeddingGenerator> {
86        Box::new(self.clone())
87    }
88
89    /// Generate Embeddings for an input list of strings by using the external API.
90    /// This method batch generates whenever possible to increase speed.
91    /// Note this method is blocking.
92    fn generate_embeddings_blocking(
93        &self,
94        input_strings: &Vec<String>,
95    ) -> Result<Vec<Vec<f32>>, ZooEmbeddingError> {
96        let input_strings: Vec<String> = input_strings
97            .iter()
98            .map(|s| s.chars().take(self.model_type.max_input_token_count()).collect())
99            .collect();
100
101        match self.model_type {
102            EmbeddingModelType::OllamaTextEmbeddingsInference(_) => {
103                let mut embeddings = Vec::new();
104                for input_string in input_strings.iter() {
105                    let embedding = self.generate_embedding_ollama_blocking(input_string)?;
106                    embeddings.push(embedding);
107                }
108                Ok(embeddings)
109            }
110        }
111    }
112
113    /// Generate an Embedding for an input string by using the external API.
114    /// Note this method is blocking.
115    fn generate_embedding_blocking(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError> {
116        let input_strings = [input_string.to_string()];
117        let input_strings: Vec<String> = input_strings
118            .iter()
119            .map(|s| s.chars().take(self.model_type.max_input_token_count()).collect())
120            .collect();
121
122        let results = self.generate_embeddings_blocking(&input_strings)?;
123        if results.is_empty() {
124            Err(ZooEmbeddingError::FailedEmbeddingGeneration(
125                "No results returned from the embedding generation".to_string(),
126            ))
127        } else {
128            Ok(results[0].clone())
129        }
130    }
131
132    /// Generate an Embedding for an input string by using the external API.
133    /// This method batch generates whenever possible to increase speed.
134    async fn generate_embeddings(&self, input_strings: &Vec<String>) -> Result<Vec<Vec<f32>>, ZooEmbeddingError> {
135        let input_strings: Vec<String> = input_strings
136            .iter()
137            .map(|s| s.chars().take(self.model_type.max_input_token_count()).collect())
138            .collect();
139
140        match self.model_type.clone() {
141            EmbeddingModelType::OllamaTextEmbeddingsInference(model) => {
142                let mut embeddings = Vec::new();
143                for input_string in input_strings.iter() {
144                    let embedding = self
145                        .generate_embedding_ollama(input_string.clone(), model.to_string())
146                        .await?;
147                    embeddings.push(embedding);
148                }
149                Ok(embeddings)
150            }
151        }
152    }
153
154    /// Generate an Embedding for an input string by using the external API.
155    async fn generate_embedding(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError> {
156        let input_strings = [input_string.to_string()];
157        let input_strings: Vec<String> = input_strings
158            .iter()
159            .map(|s| s.chars().take(self.model_type.max_input_token_count()).collect())
160            .collect();
161
162        let results = self.generate_embeddings(&input_strings).await?;
163        if results.is_empty() {
164            Err(ZooEmbeddingError::FailedEmbeddingGeneration(
165                "No results returned from the embedding generation".to_string(),
166            ))
167        } else {
168            Ok(results[0].clone())
169        }
170    }
171
172    /// Returns the EmbeddingModelType
173    fn model_type(&self) -> EmbeddingModelType {
174        self.model_type.clone()
175    }
176
177    /// Sets the EmbeddingModelType
178    fn set_model_type(&mut self, model_type: EmbeddingModelType) {
179        self.model_type = model_type
180    }
181}
182
183impl RemoteEmbeddingGenerator {
184    /// Create a RemoteEmbeddingGenerator
185    pub fn new(model_type: EmbeddingModelType, api_url: &str, api_key: Option<String>) -> RemoteEmbeddingGenerator {
186        RemoteEmbeddingGenerator {
187            model_type,
188            api_url: api_url.to_string(),
189            api_key,
190        }
191    }
192
193    /// Create a RemoteEmbeddingGenerator that uses the default model and server
194    pub fn new_default() -> RemoteEmbeddingGenerator {
195        let model_architecture =
196            EmbeddingModelType::OllamaTextEmbeddingsInference(OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM);
197        RemoteEmbeddingGenerator {
198            model_type: model_architecture,
199            api_url: DEFAULT_EMBEDDINGS_SERVER_URL.to_string(),
200            api_key: None,
201        }
202    }
203    /// Create a RemoteEmbeddingGenerator that uses the default model and server
204    pub fn new_default_local() -> RemoteEmbeddingGenerator {
205        let model_architecture =
206            EmbeddingModelType::OllamaTextEmbeddingsInference(OllamaTextEmbeddingsInference::SnowflakeArcticEmbedM);
207        RemoteEmbeddingGenerator {
208            model_type: model_architecture,
209            api_url: DEFAULT_EMBEDDINGS_LOCAL_URL.to_string(),
210            api_key: None,
211        }
212    }
213
214    /// String of the main endpoint url for generating embeddings via
215    /// Hugging face's Text Embedding Interface server
216    fn tei_endpoint_url(&self) -> String {
217        if self.api_url.ends_with('/') {
218            format!("{}embed", self.api_url)
219        } else {
220            format!("{}/embed", self.api_url)
221        }
222    }
223
224    /// String of the main endpoint url for generating embeddings via
225    /// Ollama Text Embedding Interface server
226    fn ollama_endpoint_url(&self) -> String {
227        if self.api_url.ends_with('/') {
228            format!("{}api/embeddings", self.api_url)
229        } else {
230            format!("{}/api/embeddings", self.api_url)
231        }
232    }
233
234    /// Generates embeddings using Hugging Face's Text Embedding Interface server
235    /// pub async fn generate_embedding_open_ai(&self, input_string: &str, id: &str) -> Result<Embedding, VRError> {
236    pub async fn generate_embedding_ollama(
237        &self,
238        input_string: String,
239        model: String,
240    ) -> Result<Vec<f32>, ZooEmbeddingError> {
241        let max_retries = 3;
242        let mut retry_count = 0;
243        let mut shortening_retry = 0;
244        let mut input_string = input_string.clone();
245
246        loop {
247            // Prepare the request body
248            let request_body = OllamaEmbeddingsRequestBody {
249                model: model.clone(),
250                prompt: input_string.clone(),
251            };
252
253            // Create the HTTP client with a custom timeout
254            let timeout = Duration::from_secs(60);
255            let client = ClientBuilder::new().timeout(timeout).build()?;
256
257            // Build the request
258            let mut request = client
259                .post(self.ollama_endpoint_url().to_string())
260                .header("Content-Type", "application/json")
261                .json(&request_body);
262
263            // Add the API key to the header if it's available
264            if let Some(api_key) = &self.api_key {
265                request = request.header("Authorization", format!("Bearer {}", api_key));
266            }
267
268            // Send the request
269            let response = request.send().await;
270
271            match response {
272                Ok(response) if response.status().is_success() => {
273                    let embedding_response: Result<OllamaEmbeddingsResponse, _> =
274                        response.json::<OllamaEmbeddingsResponse>().await;
275                    match embedding_response {
276                        Ok(embedding_response) => {
277                            return Ok(embedding_response.embedding);
278                        }
279                        Err(err) => {
280                            return Err(ZooEmbeddingError::RequestFailed(format!(
281                                "Failed to deserialize response JSON: {}",
282                                err
283                            )));
284                        }
285                    }
286                }
287                Ok(response) if response.status() == reqwest::StatusCode::PAYLOAD_TOO_LARGE => {
288                    // Calculate the maximum size allowed based on the number of retries
289                    let reduction_step = if shortening_retry > 1 {
290                        100 * shortening_retry
291                    } else {
292                        50
293                    };
294                    let shortened_max_size = input_string.len().saturating_sub(reduction_step).max(5);
295                    input_string = input_string.chars().take(shortened_max_size).collect();
296
297                    retry_count = 0;
298                    shortening_retry += 1;
299                    if shortening_retry > 10 {
300                        return Err(ZooEmbeddingError::RequestFailed(format!(
301                            "HTTP request failed after multiple recursive iterations shortening input. Status: {}",
302                            response.status()
303                        )));
304                    }
305                    continue;
306                }
307                Ok(response) => {
308                    return Err(ZooEmbeddingError::RequestFailed(format!(
309                        "HTTP request failed with status: {}",
310                        response.status()
311                    )));
312                }
313                Err(err) => {
314                    if retry_count < max_retries {
315                        retry_count += 1;
316                        continue;
317                    } else {
318                        return Err(ZooEmbeddingError::RequestFailed(format!(
319                            "HTTP request failed after {} retries: {}",
320                            max_retries, err
321                        )));
322                    }
323                }
324            }
325        }
326    }
327
328    /// Generate an Embedding for an input string by using the external Ollama API.
329    fn generate_embedding_ollama_blocking(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError> {
330        // Prepare the request body
331        let request_body = OllamaEmbeddingsRequestBody {
332            model: self.model_type.to_string(),
333            prompt: String::from(input_string),
334        };
335
336        // Create the HTTP client
337        let client = Client::new();
338
339        // Build the request
340        let mut request = client
341            .post(&format!("{}", self.ollama_endpoint_url()))
342            .header("Content-Type", "application/json")
343            .json(&request_body);
344
345        // Add the API key to the header if it's available
346        if let Some(api_key) = &self.api_key {
347            request = request.header("Authorization", format!("Bearer {}", api_key));
348        }
349
350        // Send the request and check for errors
351        let response = request.send().map_err(|err| {
352            // Handle any HTTP client errors here (e.g., request creation failure)
353            ZooEmbeddingError::RequestFailed(format!("HTTP request failed: {}", err))
354        })?;
355
356        // Check if the response is successful
357        if response.status().is_success() {
358            let embedding_response: OllamaEmbeddingsResponse = response.json().map_err(|err| {
359                ZooEmbeddingError::RequestFailed(format!("Failed to deserialize response JSON: {}", err))
360            })?;
361            Ok(embedding_response.embedding)
362        } else {
363            Err(ZooEmbeddingError::RequestFailed(format!(
364                "HTTP request failed with status: {}",
365                response.status()
366            )))
367        }
368    }
369
370    /// Generates embeddings using Hugging Face's Text Embedding Interface server
371    pub async fn generate_embedding_tei(
372        &self,
373        input_strings: Vec<String>,
374    ) -> Result<Vec<Vec<f32>>, ZooEmbeddingError> {
375        let max_retries = 3;
376        let mut retry_count = 0;
377        let mut shortening_retry = 0;
378        let mut current_input_strings = input_strings.clone();
379
380        loop {
381            // Prepare the request body
382            let request_body = EmbeddingArrayRequestBody {
383                inputs: current_input_strings.iter().map(|s| s.to_string()).collect(),
384            };
385
386            // Create the HTTP client with a custom timeout
387            let timeout = Duration::from_secs(60);
388            let client = ClientBuilder::new().timeout(timeout).build()?;
389
390            // Build the request
391            let mut request = client
392                .post(self.tei_endpoint_url().to_string())
393                .header("Content-Type", "application/json")
394                .json(&request_body);
395
396            // Add the API key to the header if it's available
397            if let Some(api_key) = &self.api_key {
398                request = request.header("Authorization", format!("Bearer {}", api_key));
399            }
400
401            // Send the request
402            let response = request.send().await;
403
404            match response {
405                Ok(response) if response.status().is_success() => {
406                    let embedding_response: Result<Vec<Vec<f32>>, _> = response.json::<Vec<Vec<f32>>>().await;
407                    match embedding_response {
408                        Ok(embedding_response) => {
409                            return Ok(embedding_response);
410                        }
411                        Err(err) => {
412                            return Err(ZooEmbeddingError::RequestFailed(format!(
413                                "Failed to deserialize response JSON: {}",
414                                err
415                            )));
416                        }
417                    }
418                }
419                Ok(response) if response.status() == reqwest::StatusCode::PAYLOAD_TOO_LARGE => {
420                    let max_size = current_input_strings.iter().map(|s| s.len()).max().unwrap_or(0);
421                    // Increase the number of characters removed based on the number of retries
422                    let reduction_step = if shortening_retry > 1 {
423                        100 * shortening_retry
424                    } else {
425                        50
426                    };
427                    let shortened_max_size = max_size.saturating_sub(reduction_step).max(5);
428                    current_input_strings = current_input_strings
429                        .iter()
430                        .map(|s| {
431                            if s.len() > shortened_max_size {
432                                s.chars().take(shortened_max_size).collect()
433                            } else {
434                                s.clone()
435                            }
436                        })
437                        .collect();
438                    retry_count = 0;
439                    shortening_retry += 1;
440                    if shortening_retry > 10 {
441                        return Err(ZooEmbeddingError::RequestFailed(format!(
442                            "HTTP request failed after multiple recursive iterations shortening input. Status: {}",
443                            response.status()
444                        )));
445                    }
446                    continue;
447                }
448                Ok(response) => {
449                    return Err(ZooEmbeddingError::RequestFailed(format!(
450                        "HTTP request failed with status: {}",
451                        response.status()
452                    )));
453                }
454                Err(err) => {
455                    if retry_count < max_retries {
456                        retry_count += 1;
457                        continue;
458                    } else {
459                        return Err(ZooEmbeddingError::RequestFailed(format!(
460                            "HTTP request failed after {} retries: {}",
461                            max_retries, err
462                        )));
463                    }
464                }
465            }
466        }
467    }
468
469    /// Generate an Embedding for an input string by using the external OpenAI-matching API.
470    pub async fn generate_embedding_open_ai(&self, input_string: &str) -> Result<Vec<f32>, ZooEmbeddingError> {
471        // Prepare the request body
472        let request_body = EmbeddingRequestBody {
473            input: String::from(input_string),
474            model: self.model_type().to_string(),
475        };
476
477        // Create the HTTP client
478        let client = AsyncClient::new();
479
480        // Build the request
481        let mut request = client
482            .post(self.api_url.to_string())
483            .header("Content-Type", "application/json")
484            .json(&request_body);
485
486        // Add the API key to the header if it's available
487        if let Some(api_key) = &self.api_key {
488            request = request.header("Authorization", format!("Bearer {}", api_key));
489        }
490
491        // Send the request and check for errors
492        let response = request.send().await.map_err(|err| {
493            // Handle any HTTP client errors here (e.g., request creation failure)
494            ZooEmbeddingError::RequestFailed(format!("HTTP request failed: {}", err))
495        })?;
496
497        // Check if the response is successful
498        if response.status().is_success() {
499            // Deserialize the response JSON into a struct (assuming you have an
500            // EmbeddingResponse struct)
501            let embedding_response: EmbeddingResponse = response.json().await.map_err(|err| {
502                ZooEmbeddingError::RequestFailed(format!("Failed to deserialize response JSON: {}", err))
503            })?;
504
505            // Use the response to create an Embedding instance
506            Ok(embedding_response.data[0].embedding.clone())
507        } else {
508            // Handle non-successful HTTP responses (e.g., server error)
509            Err(ZooEmbeddingError::RequestFailed(format!(
510                "HTTP request failed with status: {}",
511                response.status()
512            )))
513        }
514    }
515}
516
517#[derive(Serialize)]
518#[allow(dead_code)]
519struct EmbeddingRequestBody {
520    input: String,
521    model: String,
522}
523
524#[derive(Deserialize)]
525#[allow(dead_code)]
526struct EmbeddingResponseData {
527    embedding: Vec<f32>,
528    index: usize,
529    object: String,
530}
531
532#[derive(Deserialize)]
533#[allow(dead_code)]
534struct EmbeddingResponse {
535    object: String,
536    model: String,
537    data: Vec<EmbeddingResponseData>,
538    usage: serde_json::Value, // or define a separate struct for this if you need to use these values
539}
540
541#[derive(Serialize)]
542#[allow(dead_code)]
543struct EmbeddingArrayRequestBody {
544    inputs: Vec<String>,
545}
546
547#[derive(Debug, Serialize)]
548#[allow(dead_code)]
549struct OllamaEmbeddingsRequestBody {
550    model: String,
551    prompt: String,
552}
553
554#[derive(Deserialize)]
555#[allow(dead_code)]
556struct OllamaEmbeddingsResponse {
557    embedding: Vec<f32>,
558}