//! Client for the embeddings endpoints of the txtai API.
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::error::Error;

pub use crate::api::{API, APIResponse, IndexResults, IndexResultsBatch};

/// Handle for the embeddings endpoints of a txtai API.
///
/// Holds the `API` instance that every method of this type sends its
/// requests through.
pub struct Embeddings {
    /// Underlying API connection, created from the base url passed to `new`.
    api: API,
}

/// Embeddings implementation: one method per embeddings endpoint of the txtai API.
impl Embeddings {
    /// Creates an Embeddings instance.
    ///
    /// # Arguments
    /// * `url` - base url of txtai API
    pub fn new(url: &str) -> Embeddings {
        Self {
            api: API::new(url),
        }
    }

    /// Runs an Embeddings search. Returns Response. This method allows
    /// callers to customize the serialization of the response.
    ///
    /// # Arguments
    /// * `query` - query text
    /// * `limit` - maximum results
    ///
    /// # Errors
    /// Propagates any error returned by the API call.
    pub async fn query(&self, query: &str, limit: i32) -> APIResponse {
        // Bind the stringified limit to a local so the borrow stored in
        // `params` has an explicit owner for the whole call.
        let limit = limit.to_string();

        // Query parameters
        let params = [("query", query), ("limit", limit.as_str())];

        // Execute API call
        Ok(self.api.get("search", &params).await?)
    }

    /// Finds documents in the embeddings model most similar to the input query. Returns
    /// a list of {id: value, score: value} sorted by highest score, where id is the
    /// document id in the embeddings model.
    ///
    /// # Arguments
    /// * `query` - query text
    /// * `limit` - maximum results
    ///
    /// # Errors
    /// Propagates any error from the API call or from decoding the JSON response.
    pub async fn search(&self, query: &str, limit: i32) -> SearchResults {
        // Execute API call and map JSON
        Ok(self.query(query, limit).await?.json().await?)
    }

    /// Finds documents in the embeddings model most similar to the input queries. Returns
    /// a list of {id: value, score: value} sorted by highest score per query, where id is
    /// the document id in the embeddings model.
    ///
    /// `queries` is generic over any serializable value so that a list of queries
    /// (for example `&Vec<&str>` or `&[&str]`) can be passed; a plain `&str` is still
    /// accepted and is sent as a single JSON string, exactly as before.
    ///
    /// # Arguments
    /// * `queries` - queries text
    /// * `limit` - maximum results
    ///
    /// # Errors
    /// Propagates any error from the API call or from decoding the JSON response.
    pub async fn batchsearch<T>(&self, queries: &T, limit: i32) -> SearchResultsBatch
    where
        T: Serialize + ?Sized,
    {
        // Post parameters
        let params = json!({"queries": queries, "limit": limit});

        // Execute API call
        Ok(self.api.post("batchsearch", &params).await?.json().await?)
    }

    /// Adds a batch of documents for indexing.
    ///
    /// # Arguments
    /// * `documents` - list of {id: value, text: value}
    ///
    /// # Errors
    /// Propagates any error returned by the API call.
    pub async fn add<T: Serialize>(&self, documents: &[T]) -> APIResponse {
        // Execute API call
        Ok(self.api.post("add", &json!(documents)).await?)
    }

    /// Builds an embeddings index for previously batched documents. No further documents can be added
    /// after this call.
    ///
    /// # Errors
    /// Propagates any error returned by the API call.
    pub async fn index(&self) -> APIResponse {
        // Execute API call (no query parameters)
        Ok(self.api.get("index", &[]).await?)
    }

    /// Computes the similarity between query and list of text. Returns a list of
    /// {id: value, score: value} sorted by highest score, where id is the index
    /// in texts.
    ///
    /// # Arguments
    /// * `query` - query text
    /// * `texts` - list of text
    ///
    /// # Errors
    /// Propagates any error from the API call or from decoding the JSON response.
    pub async fn similarity(&self, query: &str, texts: &[&str]) -> IndexResults {
        // Post parameters
        let params = json!({"query": query, "texts": texts});

        // Execute API call
        Ok(self.api.post("similarity", &params).await?.json().await?)
    }

    /// Computes the similarity between list of queries and list of text. Returns a list
    /// of {id: value, score: value} sorted by highest score per query, where id is the
    /// index in texts.
    ///
    /// # Arguments
    /// * `queries` - queries text
    /// * `texts` - list of text
    ///
    /// # Errors
    /// Propagates any error from the API call or from decoding the JSON response.
    pub async fn batchsimilarity(&self, queries: &[&str], texts: &[&str]) -> IndexResultsBatch {
        // Post parameters
        let params = json!({"queries": queries, "texts": texts});

        // Execute API call
        Ok(self.api.post("batchsimilarity", &params).await?.json().await?)
    }

    /// Transforms text into an embeddings array.
    ///
    /// # Arguments
    /// * `text` - input text
    ///
    /// # Errors
    /// Propagates any error from the API call or from decoding the JSON response.
    pub async fn transform(&self, text: &str) -> Embedding {
        // Query parameters
        let params = [("text", text)];

        // Execute API call
        Ok(self.api.get("transform", &params).await?.json().await?)
    }

    /// Transforms list of text into embeddings arrays.
    ///
    /// `texts` is generic over any serializable value so that a list of texts
    /// (for example `&Vec<&str>` or `&[&str]`) can be passed; a plain `&str` is still
    /// accepted and is sent as a single JSON string, exactly as before.
    ///
    /// # Arguments
    /// * `texts` - lists of text
    ///
    /// # Errors
    /// Propagates any error from the API call or from decoding the JSON response.
    pub async fn batchtransform<T>(&self, texts: &T) -> EmbeddingBatch
    where
        T: Serialize + ?Sized,
    {
        // Execute API call
        Ok(self.api.post("batchtransform", &json!(texts)).await?.json().await?)
    }
}

// Embeddings return types
//
// Every alias is a `Result` whose error side is a boxed `std::error::Error`
// trait object.

/// Result of `transform`: a single embeddings array.
pub type Embedding = Result<Vec<f32>, Box<dyn Error>>;
/// Result of `batchtransform`: a list of embeddings arrays.
pub type EmbeddingBatch = Result<Vec<Vec<f32>>, Box<dyn Error>>;
/// Result of `search`: a list of search results for one query.
pub type SearchResults = Result<Vec<SearchResult>, Box<dyn Error>>;
/// Result of `batchsearch`: one list of search results per query.
pub type SearchResultsBatch = Result<Vec<Vec<SearchResult>>, Box<dyn Error>>;

/// Input document, serialized as `{id: value, text: value}` when passed to
/// `Embeddings::add`.
#[derive(Debug, Serialize)]
pub struct Document {
    /// Document id
    pub id: String,
    /// Document text
    pub text: String,
}

/// Search result, deserialized from a `{id: value, score: value}` entry of a
/// search response.
#[derive(Debug, Deserialize)]
pub struct SearchResult {
    /// Document id in the embeddings model
    pub id: String,
    /// Score of this result for the query
    pub score: f32,
}